Compare commits

..

1 Commits

117 changed files with 1220 additions and 4341 deletions

1
.nvmrc
View File

@@ -1 +0,0 @@
v22.12.0

View File

@@ -39,7 +39,7 @@ It has two sections - one for internal use and one for user settings:
```yaml
# Internal metadata - do not edit:
schema_version: 4.0.2
schema_version: 4
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:
host: 0.0.0.0 # serve the app on your local network
@@ -83,10 +83,6 @@ A subset of settings may be specified using CLI args:
- `--root`: specify the root directory
- `--config`: override the default `invokeai.yaml` file location
### Low-VRAM Mode
See the [Low-VRAM mode docs][low-vram] for details on enabling this feature.
### All Settings
Following the table are additional explanations for certain settings.
@@ -189,4 +185,3 @@ The `log_format` option provides several alternative formats:
[basic guide to yaml files]: https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/
[Model Marketplace API Keys]: #model-marketplace-api-keys
[low-vram]: ./features/low-vram.md

View File

@@ -22,7 +22,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
4. Follow the [manual install][manual install link] guide, with some modifications to the install command:
- Use `.` instead of `invokeai` to install from the current directory. You don't need to specify the version.
- Use `.` instead of `invokeai` to install from the current directory.
- Add `-e` after the `install` operation to make this an [editable install][editable install link]. That means your changes to the python code will be reflected when you restart the Invoke server.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 72 KiB

View File

@@ -1,129 +0,0 @@
---
title: Low-VRAM mode
---
As of v5.6.0, Invoke has a low-VRAM mode. It works on systems with dedicated GPUs (Nvidia GPUs on Windows/Linux and AMD GPUs on Linux).
This allows you to generate even if your GPU doesn't have enough VRAM to hold full models. Most users should be able to run even the beefiest models - like the ~24GB unquantised FLUX dev model.
## Enabling Low-VRAM mode
To enable Low-VRAM mode, add this line to your `invokeai.yaml` configuration file, then restart Invoke:
```yaml
enable_partial_loading: true
```
**Windows users should also [disable the Nvidia sysmem fallback](#disabling-nvidia-sysmem-fallback-windows-only)**.
It is possible to fine-tune the settings for best performance or if you still get out-of-memory errors (OOMs).
!!! tip "How to find `invokeai.yaml`"
The `invokeai.yaml` configuration file lives in your install directory. To access it, run the **Invoke Community Edition** launcher and click the install location. This will open your install directory in a file explorer window.
You'll see `invokeai.yaml` there and can edit it with any text editor. After making changes, restart Invoke.
If you don't see `invokeai.yaml`, launch Invoke once. It will create the file on its first startup.
## Details and fine-tuning
Low-VRAM mode involves 3 features, each of which can be configured or fine-tuned:
- Partial model loading
- Dynamic RAM and VRAM cache sizes
- Working memory
Read on to learn about these features and understand how to fine-tune them for your system and use-cases.
### Partial model loading
Invoke's partial model loading works by streaming model "layers" between RAM and VRAM as they are needed.
When an operation needs layers that are not in VRAM, but there isn't enough room to load them, inactive layers are offloaded to RAM to make room.
#### Enabling partial model loading
As described above, you can enable partial model loading by adding this line to `invokeai.yaml`:
```yaml
enable_partial_loading: true
```
### Dynamic RAM and VRAM cache sizes
Loading models from disk is slow and can be a major bottleneck for performance. Invoke uses two model caches - RAM and VRAM - to reduce loading from disk to a minimum.
By default, Invoke manages these caches' sizes dynamically for best performance.
#### Fine-tuning cache sizes
Prior to v5.6.0, the cache sizes were static, and for best performance, many users needed to manually fine-tune the `ram` and `vram` settings in `invokeai.yaml`.
As of v5.6.0, the caches are dynamically sized. The `ram` and `vram` settings are no longer used, and new settings are added to configure the cache.
**Most users will not need to fine-tune the cache sizes.**
But, if your GPU has enough VRAM to hold models fully, you might get a perf boost by manually setting the cache sizes in `invokeai.yaml`:
```yaml
# Set the RAM cache size to as large as possible, leaving a few GB free for the rest of your system and Invoke.
# For example, if your system has 32GB RAM, 28GB is a good value.
max_cache_ram_gb: 28
# Set the VRAM cache size to be as large as possible while leaving enough room for the working memory of the tasks you will be doing.
# For example, on a 24GB GPU that will be running unquantized FLUX without any auxiliary models,
# 18GB is a good value.
max_cache_vram_gb: 18
```
!!! tip "Max safe value for `max_cache_vram_gb`"
To determine the max safe value for `max_cache_vram_gb`, subtract `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
For example, if you have a 12GB GPU, the max safe value for `max_cache_vram_gb` is `12GB - 3GB = 9GB`.
If you had increased `device_working_mem_gb` to 4GB, then the max safe value for `max_cache_vram_gb` is `12GB - 4GB = 8GB`.
### Working memory
Invoke cannot use _all_ of your VRAM for model caching and loading. It requires some VRAM to use as working memory for various operations.
Invoke reserves 3GB VRAM as working memory by default, which is enough for most use-cases. However, it is possible to fine-tune this setting if you still get OOMs.
#### Fine-tuning working memory
You can increase the working memory size in `invokeai.yaml` to prevent OOMs:
```yaml
# The default is 3GB - bump it up to 4GB to prevent OOMs.
device_working_mem_gb: 4
```
!!! tip "Operations may request more working memory"
For some operations, we can determine VRAM requirements in advance and allocate additional working memory to prevent OOMs.
VAE decoding is one such operation. This operation converts the generation process's output into an image. For large image outputs, this might use more than the default working memory size of 3GB.
During this decoding step, Invoke calculates how much VRAM will be required to decode and requests that much VRAM from the model manager. If the amount exceeds the working memory size, the model manager will offload cached model layers from VRAM until there's enough VRAM to decode.
Once decoding completes, the model manager "reclaims" the extra VRAM allocated as working memory for future model loading operations.
### Disabling Nvidia sysmem fallback (Windows only)
On Windows, Nvidia GPUs are able to use system RAM when their VRAM fills up via **sysmem fallback**. While it sounds like a good idea on the surface, in practice it causes massive slowdowns during generation.
It is strongly suggested to disable this feature:
- Open the **NVIDIA Control Panel** app.
- Expand **3D Settings** on the left panel.
- Click **Manage 3D Settings** in the left panel.
- Find **CUDA - Sysmem Fallback Policy** in the right panel and set it to **Prefer No Sysmem Fallback**.
![cuda-sysmem-fallback](./cuda-sysmem-fallback.png)
!!! tip "Invoke does the same thing, but better"
If the sysmem fallback feature sounds familiar, that's because Invoke's partial model loading strategy is conceptually very similar - use VRAM when there's room, else fall back to RAM.
Unfortunately, the Nvidia implementation is not optimized for applications like Invoke and does more harm than good.

View File

@@ -75,14 +75,14 @@ The following commands vary depending on the version of Invoke being installed a
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.1`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm62`.
- **In all other cases, do not use an index.**
=== "Invoke v4"
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm5.2`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm52`.
- **In all other cases, do not use an index.**
8. Install the `invokeai` package. Substitute the package specifier and version.

View File

@@ -54,7 +54,7 @@ If you have an existing Invoke installation, you can select it and let the launc
- Open the **Invoke-Installer-mac-arm64.dmg** file.
- Drag the launcher to **Applications**.
- Open a terminal.
- Run `xattr -d 'com.apple.quarantine' /Applications/Invoke\ Community\ Edition.app`.
- Run `xattr -cr /Applications/Invoke-Installer.app`.
You should now be able to run the launcher.

View File

@@ -535,7 +535,7 @@ View:
**Node Link:** https://github.com/simonfuhrmann/invokeai-stereo
**Example Workflow and Output**
</br><img src="https://raw.githubusercontent.com/simonfuhrmann/invokeai-stereo/refs/heads/main/docs/example_promo_03.jpg" width="600" />
</br><img src="https://github.com/simonfuhrmann/invokeai-stereo/blob/main/docs/example_promo_03.jpg" width="500" />
--------------------------------
### Simple Skin Detection

View File

@@ -4,6 +4,7 @@
import contextlib
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
@@ -20,6 +21,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config import get_config
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
@@ -846,6 +848,74 @@ async def get_starter_models() -> StarterModelResponse:
return StarterModelResponse(starter_models=starter_models, starter_bundles=starter_bundles)
@model_manager_router.get(
"/model_cache",
operation_id="get_cache_size",
response_model=float,
summary="Get maximum size of model manager RAM or VRAM cache.",
)
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
"""Return the current RAM or VRAM cache size setting (in GB)."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
value = 0.0
if cache_type == CacheType.RAM:
value = cache.max_cache_size
elif cache_type == CacheType.VRAM:
value = cache.max_vram_cache_size
return value
@model_manager_router.put(
"/model_cache",
operation_id="set_cache_size",
response_model=float,
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
)
async def set_cache_size(
value: float = Query(description="The new value for the maximum cache size"),
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
) -> float:
"""Set the current RAM or VRAM cache size setting (in GB). ."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
app_config = get_config()
# Record initial state.
vram_old = app_config.vram
ram_old = app_config.ram
# Prepare target state.
vram_new = vram_old
ram_new = ram_old
if cache_type == CacheType.RAM:
ram_new = value
elif cache_type == CacheType.VRAM:
vram_new = value
else:
raise ValueError(f"Unexpected {cache_type=}.")
config_path = app_config.config_file_path
new_config_path = config_path.with_suffix(".yaml.new")
try:
# Try to apply the target state.
cache.max_vram_cache_size = vram_new
cache.max_cache_size = ram_new
app_config.ram = ram_new
app_config.vram = vram_new
if persist:
app_config.write_file(new_config_path)
shutil.move(new_config_path, config_path)
except Exception as e:
# If there was a failure, restore the initial state.
cache.max_cache_size = ram_old
cache.max_vram_cache_size = vram_old
app_config.ram = ram_old
app_config.vram = vram_old
raise RuntimeError("Failed to update cache size") from e
return value
@model_manager_router.get(
"/stats",
operation_id="get_stats",

View File

@@ -1,118 +0,0 @@
from typing import Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
)
from invokeai.app.invocations.fields import (
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.primitives import FloatOutput, ImageOutput, IntegerOutput, StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
BATCH_GROUP_IDS = Literal[
"None",
"Group 1",
"Group 2",
"Group 3",
"Group 4",
"Group 5",
]
class NotExecutableNodeError(Exception):
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
super().__init__(message)
pass
class BaseBatchInvocation(BaseInvocation):
batch_group_id: BATCH_GROUP_IDS = InputField(
default="None",
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
input=Input.Direct,
title="Batch Group",
)
def __init__(self):
raise NotExecutableNodeError()
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(
default=[], min_length=1, description="The images to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotExecutableNodeError()
@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
strings: list[str] = InputField(
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> StringOutput:
raise NotExecutableNodeError()
@invocation(
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
integers: list[int] = InputField(
default=[], min_length=1, description="The integers to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> IntegerOutput:
raise NotExecutableNodeError()
@invocation(
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
floats: list[float] = InputField(
default=[], min_length=1, description="The floats to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> FloatOutput:
raise NotExecutableNodeError()

View File

@@ -63,6 +63,9 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
@@ -73,13 +76,12 @@ class CompelInvocation(BaseInvocation):
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
text_encoder_info = context.models.load(self.clip.text_encoder)
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
context.models.load(self.clip.tokenizer) as tokenizer,
tokenizer_info as tokenizer,
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
@@ -103,7 +105,6 @@ class CompelInvocation(BaseInvocation):
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(self.prompt)
@@ -138,7 +139,9 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
# return zero on empty
if prompt == "" and zero_on_empty:
cpu_text_encoder = text_encoder_info.model
@@ -176,7 +179,7 @@ class SDXLPromptInvocationBase:
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
context.models.load(clip_field.tokenizer) as tokenizer,
tokenizer_info as tokenizer,
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
@@ -204,7 +207,6 @@ class SDXLPromptInvocationBase:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(prompt)
@@ -222,6 +224,7 @@ class SDXLPromptInvocationBase:
del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info
c = c.detach().to("cpu")

View File

@@ -10,9 +10,7 @@ import torchvision.transforms as T
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.adapter import T2IAdapter
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from PIL import Image
@@ -40,7 +38,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -86,14 +83,12 @@ def get_scheduler(
scheduler_info: ModelIdentifierField,
scheduler_name: str,
seed: int,
unet_config: AnyModelConfig,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -105,17 +100,10 @@ def get_scheduler(
"_backup": scheduler_config,
}
if hasattr(unet_config, "prediction_type"):
scheduler_config["prediction_type"] = unet_config.prediction_type
# make dpmpp_sde reproducable(seed can be passed only in initializer)
if scheduler_class is DPMSolverSDEScheduler:
scheduler_config["noise_sampler_seed"] = seed
if scheduler_class is DPMSolverMultistepScheduler or scheduler_class is DPMSolverSinglestepScheduler:
if scheduler_config["_class_name"] == "DEISMultistepScheduler" and scheduler_config["algorithm_type"] == "deis":
scheduler_config["algorithm_type"] = "dpmsolver++"
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
@@ -423,7 +411,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
device: torch.device,
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> list[ControlNetData] | None:
@@ -465,7 +452,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=device,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
@@ -560,6 +547,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -568,7 +556,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
single_ipa_images = [
context.images.get_pil(image.image_name, mode="RGB") for image in single_ipa_image_fields
]
with context.models.load(single_ip_adapter.image_encoder_model) as image_encoder_model:
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
@@ -618,7 +606,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
device: torch.device,
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
@@ -634,6 +621,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -649,7 +637,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
t2i_adapter_model: T2IAdapter
with context.models.load(t2i_adapter_field.t2i_adapter_model) as t2i_adapter_model:
with t2i_adapter_loaded_model as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
@@ -669,7 +657,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
width=control_width_resize,
height=control_height_resize,
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
device=device,
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
@@ -834,9 +822,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
@@ -856,7 +841,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
@@ -868,6 +852,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
@@ -939,8 +926,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(denoise_ctx),
@@ -961,7 +950,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
device = TorchDevice.choose_torch_device()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -976,7 +964,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context,
self.t2i_adapter,
latents.shape,
device=device,
do_classifier_free_guidance=True,
)
@@ -1008,9 +995,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
del lora_info
return
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
@@ -1023,20 +1012,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=device, dtype=unet.dtype)
noise = noise.to(device=unet.device, dtype=unet.dtype)
if mask is not None:
mask = mask.to(device=device, dtype=unet.dtype)
mask = mask.to(device=unet.device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=device, dtype=unet.dtype)
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
pipeline = self.create_pipeline(unet, scheduler)
@@ -1046,7 +1034,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=device,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
@@ -1059,7 +1047,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
control_input=self.control,
latents_shape=latents.shape,
device=device,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
@@ -1077,7 +1064,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=device,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,

View File

@@ -199,8 +199,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
else None
)
transformer_config = context.models.get_config(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")
# Calculate the timestep schedule.
timesteps = get_schedule(
@@ -276,7 +276,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# TODO(ryand): We should really do this in a separate invocation to benefit from caching.
ip_adapter_fields = self._normalize_ip_adapter_fields()
pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
ip_adapter_fields, context, device=x.device
ip_adapter_fields, context
)
cfg_scale = self.prep_cfg_scale(
@@ -299,11 +299,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
)
# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(
context.models.load(self.transformer.transformer).model_on_device()
)
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
assert isinstance(transformer, Flux)
config = transformer_config
config = transformer_info.config
assert config is not None
# Determine if the model is quantized.
@@ -514,18 +512,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
# keep peak memory down.
controlnet_conds: list[torch.Tensor] = []
for controlnet in controlnets:
for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
image = context.images.get_pil(controlnet.image.image_name)
# HACK(ryand): We have to load the ControlNet model to determine whether the VAE needs to be run. We really
# shouldn't have to load the model here. There's a risk that the model will be dropped from the model cache
# before we load it into VRAM and thus we'll have to load it again (context:
# https://github.com/invoke-ai/InvokeAI/issues/7513).
controlnet_model = context.models.load(controlnet.control_model)
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
@@ -555,8 +550,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Finally, load the ControlNet models and initialize the ControlNet extensions.
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
for controlnet, controlnet_cond in zip(controlnets, controlnet_conds, strict=True):
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
for controlnet, controlnet_cond, controlnet_model in zip(
controlnets, controlnet_conds, controlnet_models, strict=True
):
model = exit_stack.enter_context(controlnet_model)
if isinstance(model, XLabsControlNetFlux):
controlnet_extensions.append(
@@ -626,7 +623,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
self,
ip_adapter_fields: list[IPAdapterField],
context: InvocationContext,
device: torch.device,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
clip_image_processor = CLIPImageProcessor()
@@ -666,11 +662,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds
clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds
pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)

View File

@@ -10,10 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
@@ -78,8 +74,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)

View File

@@ -2,7 +2,7 @@ from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
@@ -69,14 +69,17 @@ class FluxTextEncoderInvocation(BaseInvocation):
)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
assert isinstance(t5_tokenizer, T5Tokenizer)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
@@ -87,20 +90,22 @@ class FluxTextEncoderInvocation(BaseInvocation):
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
prompt = [self.prompt]
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
prompt = [self.prompt]
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
context.models.load(self.clip.tokenizer) as clip_tokenizer,
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:

View File

@@ -3,7 +3,6 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -25,7 +24,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Latents to Image",
tags=["latents", "image", "vae", "l2i", "flux"],
category="latents",
version="1.0.1",
version="1.0.0",
)
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -39,23 +38,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision).
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 1090 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
# We add a 20% buffer to the working memory estimate to be safe.
working_memory = working_memory * 1.2
return int(working_memory)
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

View File

@@ -23,7 +23,6 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
@@ -162,12 +161,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
crop: bool = InputField(default=False, description="Crop to base image dimensions")
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.images.get_pil(self.base_image.image_name, mode="RGBA")
image = context.images.get_pil(self.image.image_name, mode="RGBA")
base_image = context.images.get_pil(self.base_image.image_name)
image = context.images.get_pil(self.image.image_name)
mask = None
if self.mask is not None:
mask = context.images.get_pil(self.mask.image_name, mode="L")
mask = ImageOps.invert(mask)
mask = context.images.get_pil(self.mask.image_name)
mask = ImageOps.invert(mask.convert("L"))
# TODO: probably shouldn't invert mask here... should user be required to do it?
min_x = min(0, self.x)
@@ -177,11 +176,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
new_image.paste(base_image, (abs(min_x), abs(min_y)))
# Create a temporary image to paste the image with transparency
temp_image = Image.new("RGBA", new_image.size)
temp_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
new_image = Image.alpha_composite(new_image, temp_image)
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
if self.crop:
base_w, base_h = base_image.size
@@ -306,44 +301,14 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
image = context.images.get_pil(self.image.image_name)
# Split the image into RGBA channels
r, g, b, a = image.split()
# Premultiply RGB channels by alpha
premultiplied_image = ImageChops.multiply(image, a.convert("RGBA"))
premultiplied_image.putalpha(a)
# Apply the blur
blur = (
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
)
blurred_image = premultiplied_image.filter(blur)
blur_image = image.filter(blur)
# Split the blurred image into RGBA channels
r, g, b, a_orig = blurred_image.split()
# Convert to float using NumPy. float 32/64 division are much faster than float 16
r = numpy.array(r, dtype=numpy.float32)
g = numpy.array(g, dtype=numpy.float32)
b = numpy.array(b, dtype=numpy.float32)
a = numpy.array(a_orig, dtype=numpy.float32) / 255.0 # Normalize alpha to [0, 1]
# Unpremultiply RGB channels by alpha
r /= a + 1e-6 # Add a small epsilon to avoid division by zero
g /= a + 1e-6
b /= a + 1e-6
# Convert back to PIL images
r = Image.fromarray(numpy.uint8(numpy.clip(r, 0, 255)))
g = Image.fromarray(numpy.uint8(numpy.clip(g, 0, 255)))
b = Image.fromarray(numpy.uint8(numpy.clip(b, 0, 255)))
# Merge back into a single image
result_image = Image.merge("RGBA", (r, g, b, a_orig))
image_dto = context.images.save(image=result_image)
image_dto = context.images.save(image=blur_image)
return ImageOutput.build(image_dto)
@@ -1090,67 +1055,3 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=generated_image)
return ImageOutput.build(image_dto)
@invocation(
"img_noise",
title="Add Image Noise",
tags=["image", "noise"],
category="image",
version="1.0.1",
)
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add noise to an image"""
image: ImageField = InputField(description="The image to add noise to")
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description=FieldDescriptions.seed,
)
noise_type: Literal["gaussian", "salt_and_pepper"] = InputField(
default="gaussian",
description="The type of noise to add",
)
amount: float = InputField(default=0.1, ge=0, le=1, description="The amount of noise to add")
noise_color: bool = InputField(default=True, description="Whether to add colored noise")
size: int = InputField(default=1, ge=1, description="The size of the noise points")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
# Save out the alpha channel
alpha = image.getchannel("A")
# Set the seed for numpy random
rs = numpy.random.RandomState(numpy.random.MT19937(numpy.random.SeedSequence(self.seed)))
if self.noise_type == "gaussian":
if self.noise_color:
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size, 3)) * 255
else:
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size)) * 255
noise = numpy.stack([noise] * 3, axis=-1)
elif self.noise_type == "salt_and_pepper":
if self.noise_color:
noise = rs.choice(
[0, 255], (image.height // self.size, image.width // self.size, 3), p=[1 - self.amount, self.amount]
)
else:
noise = rs.choice(
[0, 255], (image.height // self.size, image.width // self.size), p=[1 - self.amount, self.amount]
)
noise = numpy.stack([noise] * 3, axis=-1)
noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
(image.width, image.height), Image.Resampling.NEAREST
)
noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")
# Paste back the alpha channel
noisy_image.putalpha(alpha)
image_dto = context.images.save(image=noisy_image)
return ImageOutput.build(image_dto)

View File

@@ -26,7 +26,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -99,7 +98,7 @@ class ImageToLatentsInvocation(BaseInvocation):
)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)

View File

@@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.3.1",
version="1.3.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -53,58 +53,16 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
def _estimate_working_memory(
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision). This estimate is accurate for both SD1 and SDXL.
element_size = 4 if self.fp32 else 2
scaling_constant = 960 # Determined experimentally.
if use_tiling:
tile_size = self.tile_size
if tile_size == 0:
tile_size = vae.tile_sample_min_size
assert isinstance(tile_size, int)
out_h = tile_size
out_w = tile_size
working_memory = out_h * out_w * element_size * scaling_constant
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
# and number of tiles. We could make this more precise in the future, but this should be good enough for
# most use cases.
working_memory = working_memory * 1.25
else:
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
working_memory = out_h * out_w * element_size * scaling_constant
if self.fp32:
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
working_memory += 250 * 2**20
# We add 20% to the working memory estimate to be safe.
working_memory = int(working_memory * 1.2)
return working_memory
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
use_tiling = self.tiled or context.config.get().force_tiled_decode
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
):
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(TorchDevice.choose_torch_device())
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
@@ -130,7 +88,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16)
latents = latents.half()
if use_tiling:
if self.tiled or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()

View File

@@ -7,6 +7,7 @@ import torch
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -538,3 +539,23 @@ class BoundingBoxInvocation(BaseInvocation):
# endregion
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "internal"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")

View File

@@ -16,7 +16,6 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -40,7 +39,7 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.disable_tiling()
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
# TODO: Use seed to make sampling reproducible.

View File

@@ -6,7 +6,6 @@ from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -27,7 +26,7 @@ from invokeai.backend.util.devices import TorchDevice
title="SD3 Latents to Image",
tags=["latents", "image", "vae", "l2i", "sd3"],
category="latents",
version="1.3.1",
version="1.3.0",
)
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -41,34 +40,16 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision).
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 1230 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant
# We add a 20% buffer to the working memory estimate to be safe.
working_memory = working_memory * 1.2
return int(working_memory)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
):
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE")
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(TorchDevice.choose_torch_device())
latents = latents.to(vae.device)
vae.disable_tiling()

View File

@@ -10,10 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.model_manager.config import SubModelType
@@ -92,8 +88,16 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
if self.clip_g_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
)
tokenizer_t5 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model or self.model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model or self.model)
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
return Sd3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),

View File

@@ -21,7 +21,6 @@ from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@@ -87,11 +86,14 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
context.util.signal_progress("Running T5 encoder")
assert isinstance(t5_text_encoder, T5EncoderModel)
@@ -118,7 +120,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
f" {max_seq_len} tokens: {removed_text}"
)
prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0]
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
@@ -126,12 +128,14 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
prompt = [self.prompt]
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
context.models.load(clip_model.tokenizer) as clip_tokenizer,
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
context.util.signal_progress("Running CLIP encoder")
@@ -181,7 +185,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]

View File

@@ -22,7 +22,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
from invokeai.backend.util.devices import TorchDevice
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
@@ -103,7 +102,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=spandrel_model.dtype)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on each tile.
pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
@@ -117,7 +116,9 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
raise CanceledException
# Extract the current tile from the input tensor.
input_tile = image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)
@@ -150,12 +151,15 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return pil_image
@torch.no_grad()
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
message=f"Processing tile {step}/{total_steps}",
@@ -163,7 +167,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
)
# Do the upscaling.
with context.models.load(self.image_to_image_model) as spandrel_model:
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Upscale the image
@@ -196,12 +200,15 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
)
@torch.no_grad()
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
target_width = int(image.width * self.scale)
@@ -214,7 +221,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
)
# Do the upscaling.
with context.models.load(self.image_to_image_model) as spandrel_model:
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
iteration = 1

View File

@@ -201,24 +201,25 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
yield (lora_info.model, lora.weight)
del lora_info
device = TorchDevice.choose_torch_device()
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with (
ExitStack() as exit_stack,
context.models.load(self.unet.unet) as unet,
unet_info as unet,
LayerPatcher.apply_smart_model_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=device, dtype=unet.dtype)
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
@@ -227,7 +228,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=device,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
@@ -240,7 +241,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
device=device,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
@@ -266,7 +266,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=device,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,

View File

@@ -13,6 +13,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Literal, Optional
import psutil
import yaml
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
@@ -24,6 +25,8 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
@@ -33,6 +36,24 @@ LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.2"
def get_default_ram_cache_size() -> float:
"""Run a heuristic for the default RAM cache based on installed RAM."""
# On some machines, psutil.virtual_memory().total gives a value that is slightly less than the actual RAM, so the
# limits are set slightly lower than than what we expect the actual RAM to be.
GB = 1024**3
max_ram = psutil.virtual_memory().total / GB
if max_ram >= 60:
return 15.0
if max_ram >= 30:
return 7.5
if max_ram >= 14:
return 4.0
return 2.1 # 2.1 is just large enough for sd 1.5 ;-)
class URLRegexTokenPair(BaseModel):
url_regex: str = Field(description="Regular expression to match against the URL")
token: str = Field(description="Token to use when the URL matches the regex")
@@ -82,14 +103,10 @@ class InvokeAIAppConfig(BaseSettings):
profile_graphs: Enable graph profiling using `cProfile`.
profile_prefix: An optional prefix for profile output files.
profiles_dir: Path to profiles output directory.
max_cache_ram_gb: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
max_cache_vram_gb: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
@@ -157,15 +174,10 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
# CACHE
max_cache_ram_gb: Optional[float] = Field(default=None, gt=0, description="The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.")
max_cache_vram_gb: Optional[float] = Field(default=None, ge=0, description="The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.")
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
# Deprecated CACHE configs
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.")
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")

View File

@@ -82,12 +82,11 @@ class ModelManagerService(ModelManagerServiceBase):
logger.setLevel(app_config.log_level.upper())
ram_cache = ModelCache(
execution_device_working_mem_gb=app_config.device_working_mem_gb,
enable_partial_loading=app_config.enable_partial_loading,
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
execution_device=execution_device or TorchDevice.choose_torch_device(),
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
loader = ModelLoadService(
app_config=app_config,

View File

@@ -108,16 +108,8 @@ class Batch(BaseModel):
return v
for batch_data_list in v:
for datum in batch_data_list:
if not datum.items:
continue
# Special handling for numbers - they can be mixed
# TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
if all(isinstance(item, (int, float)) for item in datum.items):
continue
# Get the type of the first item in the list
first_item_type = type(datum.items[0])
first_item_type = type(datum.items[0]) if datum.items else None
for item in datum.items:
if type(item) is not first_item_type:
raise BatchItemsTypeError("All items in a batch must have the same type")

View File

@@ -1,26 +0,0 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.backend.model_manager.config import BaseModelType, SubModelType
def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 encoder model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")
def preprocess_t5_tokenizer_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 tokenizer model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")

View File

@@ -8,7 +8,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
from invokeai.backend.util.devices import TorchDevice
class XLabsIPAdapterExtension:
@@ -46,7 +45,7 @@ class XLabsIPAdapterExtension:
) -> torch.Tensor:
clip_image_processor = CLIPImageProcessor()
clip_image: torch.Tensor = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=TorchDevice.choose_torch_device(), dtype=image_encoder.dtype)
clip_image = clip_image.to(device=image_encoder.device, dtype=image_encoder.dtype)
clip_image_embeds = image_encoder(clip_image).image_embeds
return clip_image_embeds

View File

@@ -1,19 +1,11 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
from invokeai.backend.util.devices import TorchDevice
from transformers import PreTrainedModel, PreTrainedTokenizer
class HFEncoder(nn.Module):
def __init__(
self,
encoder: PreTrainedModel,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
is_clip: bool,
max_length: int,
):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
super().__init__()
self.max_length = max_length
self.is_clip = is_clip
@@ -34,7 +26,7 @@ class HFEncoder(nn.Module):
)
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(TorchDevice.choose_torch_device()),
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)

View File

@@ -18,7 +18,6 @@ from invokeai.backend.image_util.util import (
resize_image_to_resolution,
safe_step,
)
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class DoubleConvBlock(torch.nn.Module):
@@ -110,7 +109,7 @@ class HEDProcessor:
Returns:
The detected edges.
"""
device = get_effective_device(self.network)
device = next(iter(self.network.parameters())).device
np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
np_image = resize_image_to_resolution(np_image, detect_resolution)
@@ -184,7 +183,7 @@ class HEDEdgeDetector:
The detected edges.
"""
device = get_effective_device(self.model)
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(image)

View File

@@ -7,7 +7,6 @@ from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
def norm_img(np_img):
@@ -32,7 +31,7 @@ class LaMA:
mask = norm_img(mask)
mask = (mask > 0) * 1
device = get_effective_device(self._model)
device = next(self._model.buffers()).device
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)

View File

@@ -17,7 +17,6 @@ from invokeai.backend.image_util.util import (
pil_to_np,
resize_image_to_resolution,
)
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class ResidualBlock(nn.Module):
@@ -131,7 +130,7 @@ class LineartProcessor:
Returns:
The detected lineart.
"""
device = get_effective_device(self.model)
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
@@ -202,7 +201,7 @@ class LineartEdgeDetector:
Returns:
The detected edges.
"""
device = get_effective_device(self.model)
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(image)

View File

@@ -19,7 +19,6 @@ from invokeai.backend.image_util.util import (
pil_to_np,
resize_image_to_resolution,
)
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class UnetGenerator(nn.Module):
@@ -172,7 +171,7 @@ class LineartAnimeProcessor:
Returns:
The detected lineart.
"""
device = get_effective_device(self.model)
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
@@ -240,7 +239,7 @@ class LineartAnimeEdgeDetector:
def run(self, image: Image.Image) -> Image.Image:
"""Processes an image and returns the detected edges."""
device = get_effective_device(self.model)
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(image)

View File

@@ -14,8 +14,6 @@ import numpy as np
import torch
from torch.nn import functional as F
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
'''
@@ -51,7 +49,7 @@ def pred_lines(image, model,
dist_thr=20.0):
h, w, _ = image.shape
device = get_effective_device(model)
device = next(iter(model.parameters())).device
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
@@ -110,7 +108,7 @@ def pred_squares(image,
'''
h, w, _ = image.shape
original_shape = [h, w]
device = get_effective_device(model)
device = next(iter(model.parameters())).device
resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)

View File

@@ -13,7 +13,6 @@ from PIL import Image
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class NormalMapDetector:
@@ -65,7 +64,7 @@ class NormalMapDetector:
def run(self, image: Image.Image):
"""Processes an image and returns the detected normal map."""
device = get_effective_device(self.model)
device = next(iter(self.model.parameters())).device
np_image = pil_to_np(image)
height, width, _channels = np_image.shape

View File

@@ -11,7 +11,6 @@ from PIL import Image
from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
class PIDINetDetector:
@@ -46,7 +45,7 @@ class PIDINetDetector:
) -> Image.Image:
"""Processes an image and returns the detected edges."""
device = get_effective_device(self.model)
device = next(iter(self.model.parameters())).device
np_img = pil_to_np(image)
np_img = normalize_image_channel_count(np_img)

View File

@@ -57,31 +57,25 @@ class LoadedModelWithoutConfig:
self._cache = cache
def __enter__(self) -> AnyModel:
self._cache.lock(self._cache_record, None)
self._cache.lock(self._cache_record)
return self.model
def __exit__(self, *args: Any, **kwargs: Any) -> None:
self._cache.unlock(self._cache_record)
@contextmanager
def model_on_device(
self, working_mem_bytes: Optional[int] = None
) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device.
:param working_mem_bytes: The amount of working memory to keep available on the compute device when loading the
model.
"""
self._cache.lock(self._cache_record, working_mem_bytes)
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
self._cache.lock(self._cache_record)
try:
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
yield (self._cache_record.state_dict, self._cache_record.model)
finally:
self._cache.unlock(self._cache_record)
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self._cache_record.cached_model.model
return self._cache_record.model
class LoadedModel(LoadedModelWithoutConfig):

View File

@@ -1,21 +1,38 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
import torch
@dataclass
class CacheRecord:
"""A class that represents a model in the model cache."""
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
# Cache key.
key: str
# Model in memory.
cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
model: Any
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0
def lock(self) -> None:
@@ -28,6 +45,6 @@ class CacheRecord:
assert self._locks >= 0
@property
def is_locked(self) -> bool:
def locked(self) -> bool:
"""Return true if record is locked."""
return self._locks > 0

View File

@@ -7,6 +7,18 @@ from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from invokeai.backend.util.logging import InvokeAILogger
def set_nested_attr(obj: object, attr: str, value: object):
"""A helper function that extends setattr() to support nested attributes.
Example:
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
"""
attrs = attr.split(".")
for attr in attrs[:-1]:
obj = getattr(obj, attr)
setattr(obj, attrs[-1], value)
class CachedModelWithPartialLoad:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
@@ -21,14 +33,9 @@ class CachedModelWithPartialLoad:
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
# A dictionary of the size of each tensor in the state dict.
# HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
# consistency in case the application code has modified the model's size (e.g. by casting to a different
# precision). Of course, this means that we are making model cache load/unload decisions based on model size
# data that may not be fully accurate.
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
self._total_bytes = sum(self._state_dict_bytes.values())
# TODO(ryand): Handle the case where the model sizes changes after initial load (e.g. due to dtype casting).
# Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
self._cur_vram_bytes: int | None = None
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
@@ -84,9 +91,7 @@ class CachedModelWithPartialLoad:
if self._cur_vram_bytes is None:
cur_state_dict = self._model.state_dict()
self._cur_vram_bytes = sum(
self._state_dict_bytes[k]
for k, v in cur_state_dict.items()
if v.device.type == self._compute_device.type
calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
)
return self._cur_vram_bytes
@@ -118,7 +123,7 @@ class CachedModelWithPartialLoad:
if param.device.type == self._compute_device.type:
continue
param_size = self._state_dict_bytes[key]
param_size = calc_tensor_size(param)
cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size
@@ -135,7 +140,7 @@ class CachedModelWithPartialLoad:
if param.device.type == self._compute_device.type:
continue
param_size = self._state_dict_bytes[key]
param_size = calc_tensor_size(param)
if vram_bytes_loaded + param_size > vram_bytes_to_load:
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
# worth continuing to search for a smaller parameter that would fit?
@@ -156,6 +161,7 @@ class CachedModelWithPartialLoad:
if fully_loaded:
self._set_autocast_enabled_in_all_modules(False)
# TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
else:
self._set_autocast_enabled_in_all_modules(True)
@@ -166,17 +172,13 @@ class CachedModelWithPartialLoad:
return vram_bytes_loaded
@torch.no_grad()
def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weights_in_vram: bool = False) -> int:
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
:param keep_required_weights_in_vram: If True, any weights that must be kept in VRAM to run the model will be
kept in VRAM.
Returns:
The number of bytes unloaded from VRAM.
"""
vram_bytes_freed = 0
required_weights_in_vram = 0
offload_device = "cpu"
cur_state_dict = self._model.state_dict()
@@ -187,12 +189,8 @@ class CachedModelWithPartialLoad:
if param.device.type == offload_device:
continue
if keep_required_weights_in_vram and key in self._keys_in_modules_that_do_not_support_autocast:
required_weights_in_vram += self._state_dict_bytes[key]
continue
cur_state_dict[key] = self._cpu_state_dict[key]
vram_bytes_freed += self._state_dict_bytes[key]
vram_bytes_freed += calc_tensor_size(param)
if vram_bytes_freed > 0:
self._model.load_state_dict(cur_state_dict, assign=True)

View File

@@ -1,33 +0,0 @@
from contextlib import contextmanager
import torch
from invokeai.backend.util.logging import InvokeAILogger
@contextmanager
def log_operation_vram_usage(operation_name: str):
"""A helper function for tuning working memory requirements for memory-intensive ops.
Sample usage:
```python
with log_operation_vram_usage("some_operation"):
some_operation()
```
"""
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
max_allocated_before = torch.cuda.max_memory_allocated()
max_reserved_before = torch.cuda.max_memory_reserved()
try:
yield
finally:
torch.cuda.synchronize()
max_allocated_after = torch.cuda.max_memory_allocated()
max_reserved_after = torch.cuda.max_memory_reserved()
logger = InvokeAILogger.get_logger()
logger.info(
f">>>{operation_name} Peak VRAM allocated: {(max_allocated_after - max_allocated_before) / 2**20} MB, "
f"Peak VRAM reserved: {(max_reserved_after - max_reserved_before) / 2**20} MB"
)

View File

@@ -1,29 +1,24 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
import gc
import logging
import math
import time
from logging import Logger
from typing import Dict, List, Optional
import psutil
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
apply_custom_layers_to_model,
)
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter
# Size of a GB in bytes.
GB = 2**30
@@ -34,7 +29,6 @@ MB = 2**20
# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
"""Get the cache key for a model based on the optional submodel type."""
if submodel_type:
return f"{model_key}:{submodel_type.value}"
else:
@@ -76,51 +70,61 @@ class ModelCache:
def __init__(
self,
execution_device_working_mem_gb: float,
enable_partial_loading: bool,
max_ram_cache_size_gb: float | None = None,
max_vram_cache_size_gb: float | None = None,
execution_device: torch.device | str = "cuda",
storage_device: torch.device | str = "cpu",
max_cache_size: float,
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
lazy_offloading: bool = True,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""Initialize the model RAM cache.
"""
Initialize the model RAM cache.
:param execution_device_working_mem_gb: The amount of working memory to keep on the GPU (in GB) i.e. non-model
VRAM.
:param enable_partial_loading: Whether to enable partial loading of models.
:param max_ram_cache_size_gb: The maximum amount of CPU RAM to use for model caching in GB. This parameter is
kept to maintain compatibility with previous versions of the model cache, but should be deprecated in the
future. If set, this parameter overrides the default cache size logic.
:param max_vram_cache_size_gb: The amount of VRAM to use for model caching in GB. This parameter is kept to
maintain compatibility with previous versions of the model cache, but should be deprecated in the future.
If set, this parameter overrides the default cache size logic.
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
:param logger: InvokeAILogger to use (otherwise creates one)
"""
self._enable_partial_loading = enable_partial_loading
self._execution_device_working_mem_gb = execution_device_working_mem_gb
self._execution_device: torch.device = torch.device(execution_device)
self._storage_device: torch.device = torch.device(storage_device)
self._max_ram_cache_size_gb = max_ram_cache_size_gb
self._max_vram_cache_size_gb = max_vram_cache_size_gb
self._logger = PrefixedLoggerAdapter(
logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE"
)
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
self._cached_models: Dict[str, CacheRecord] = {}
self._cache_stack: List[str] = []
@property
def max_cache_size(self) -> float:
"""Return the cap on cache size."""
return self._max_cache_size
@max_cache_size.setter
def max_cache_size(self, value: float) -> None:
"""Set the cap on cache size."""
self._max_cache_size = value
@property
def max_vram_cache_size(self) -> float:
"""Return the cap on vram cache size."""
return self._max_vram_cache_size
@max_vram_cache_size.setter
def max_vram_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
self._max_vram_cache_size = value
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
@@ -128,17 +132,17 @@ class ModelCache:
@stats.setter
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collecting cache statistics."""
"""Set the CacheStats object for collectin cache statistics."""
self._stats = stats
def put(self, key: str, model: AnyModel) -> None:
"""Add a model to the cache."""
def put(
self,
key: str,
model: AnyModel,
) -> None:
"""Insert model into the cache."""
if key in self._cached_models:
self._logger.debug(
f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary."
)
return
size = calc_model_size_by_data(self._logger, model)
self.make_room(size)
@@ -146,26 +150,17 @@ class ModelCache:
if isinstance(model, torch.nn.Module):
apply_custom_layers_to_model(model)
# Partial loading only makes sense on CUDA.
# - When running on CPU, there is no 'loading' to do.
# - When running on MPS, memory is shared with the CPU, so the default OS memory management already handles this
# well.
running_with_cuda = self._execution_device.type == "cuda"
# Wrap model.
if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
else:
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
running_on_cpu = self._execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
self._logger.debug(
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
)
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
def get(
self,
key: str,
stats_name: Optional[str] = None,
) -> CacheRecord:
"""Retrieve a model from the cache.
:param key: Model key
@@ -179,7 +174,6 @@ class ModelCache:
else:
if self.stats:
self.stats.misses += 1
self._logger.debug(f"Cache miss: {key}")
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
@@ -187,44 +181,37 @@ class ModelCache:
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
# This moves the entry to the top (right end) of the stack.
# this moves the entry to the top (right end) of the stack
self._cache_stack = [k for k in self._cache_stack if k != key]
self._cache_stack.append(key)
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
return cache_entry
def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
def lock(self, cache_entry: CacheRecord) -> None:
"""Lock a model for use and move it into VRAM."""
if cache_entry.key not in self._cached_models:
self._logger.info(
f"Locking model cache entry {cache_entry.key} "
f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from "
"the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code "
"(See https://github.com/invoke-ai/InvokeAI/issues/7513)."
f"Locking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
"already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
"in the invocation code."
)
# cache_entry = self._cached_models[key]
cache_entry.lock()
self._logger.debug(
f"Locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)
if self._execution_device.type == "cpu":
# Models don't need to be loaded into VRAM if we're running on CPU.
return
try:
self._load_locked_model(cache_entry, working_mem_bytes)
self._logger.debug(
f"Finished locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)
if self._lazy_offloading:
self._offload_unlocked_models(cache_entry.size)
self._move_model_to_device(cache_entry, self._execution_device)
cache_entry.loaded = True
self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}")
self._print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._logger.warning("Insufficient GPU memory to load model. Aborting")
cache_entry.unlock()
@@ -233,333 +220,201 @@ class ModelCache:
cache_entry.unlock()
raise
self._log_cache_state()
def unlock(self, cache_entry: CacheRecord) -> None:
"""Unlock a model."""
if cache_entry.key not in self._cached_models:
self._logger.info(
f"Unlocking model cache entry {cache_entry.key} "
f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from "
"the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code "
"(See https://github.com/invoke-ai/InvokeAI/issues/7513)."
f"Unlocking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
"already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
"in the invocation code."
)
# cache_entry = self._cached_models[key]
cache_entry.unlock()
self._logger.debug(
f"Unlocked model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)
if not self._lazy_offloading:
self._offload_unlocked_models(0)
self._print_cuda_stats()
def _load_locked_model(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int] = None) -> None:
"""Helper function for self.lock(). Loads a locked model into VRAM."""
start_time = time.time()
# Calculate model_vram_needed, the amount of additional VRAM that will be used if we fully load the model into
# VRAM.
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
model_total_bytes = cache_entry.cached_model.total_bytes()
model_vram_needed = model_total_bytes - model_cur_vram_bytes
vram_available = self._get_vram_available(working_mem_bytes)
self._logger.debug(
f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
)
# Make room for the model in VRAM.
# 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
# 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
# possible.
vram_bytes_freed = self._offload_unlocked_models(model_vram_needed, working_mem_bytes)
self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
# Check the updated vram_available after offloading.
vram_available = self._get_vram_available(working_mem_bytes)
self._logger.debug(
f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
)
if vram_available < 0:
# There is insufficient VRAM available. As a last resort, try to unload the model being locked from VRAM,
# as it may still be loaded from a previous use.
vram_bytes_freed_from_own_model = self._move_model_to_ram(cache_entry, -vram_available)
vram_available = self._get_vram_available(working_mem_bytes)
self._logger.debug(
f"Unloaded {vram_bytes_freed_from_own_model/MB:.2f}MB from the model being locked ({cache_entry.key})."
)
# Move as much of the model as possible into VRAM.
# For testing, only allow 10% of the model to be loaded into VRAM.
# vram_available = int(model_vram_needed * 0.1)
# We add 1 MB to the available VRAM to account for small errors in memory tracking (e.g. off-by-one). A fully
# loaded model is much faster than a 95% loaded model.
model_bytes_loaded = self._move_model_to_vram(cache_entry, vram_available + MB)
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
vram_available = self._get_vram_available(working_mem_bytes)
loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
self._logger.info(
f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto "
f"{self._execution_device.type} device in {(time.time() - start_time):.2f}s. "
f"Total model size: {model_total_bytes/MB:.2f}MB, "
f"VRAM: {model_cur_vram_bytes/MB:.2f}MB ({loaded_percent:.1%})"
)
self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ")
self._logger.debug(
f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
)
def _move_model_to_vram(self, cache_entry: CacheRecord, vram_available: int) -> int:
try:
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
return cache_entry.cached_model.partial_load_to_vram(vram_available)
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
# Partial load is not supported, so we have not choice but to try and fit it all into VRAM.
return cache_entry.cached_model.full_load_to_vram()
else:
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
self._logger.warning("Insufficient GPU memory to load model. Aborting")
# If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely.
self._delete_cache_entry(cache_entry)
raise
def _move_model_to_ram(self, cache_entry: CacheRecord, vram_bytes_to_free: int) -> int:
try:
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
return cache_entry.cached_model.partial_unload_from_vram(
vram_bytes_to_free, keep_required_weights_in_vram=cache_entry.is_locked
)
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
return cache_entry.cached_model.full_unload_from_vram()
else:
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
except Exception:
# If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely.
self._delete_cache_entry(cache_entry)
raise
def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the amount of additional VRAM available for the cache to use (takes into account the working
memory).
"""
# If self._max_vram_cache_size_gb is set, then it overrides the default logic.
if self._max_vram_cache_size_gb is not None:
vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
return vram_total_available_to_cache - self._get_vram_in_use()
working_mem_bytes_default = int(self._execution_device_working_mem_gb * GB)
working_mem_bytes = max(working_mem_bytes or working_mem_bytes_default, working_mem_bytes_default)
if self._execution_device.type == "cuda":
# TODO(ryand): It is debatable whether we should use memory_reserved() or memory_allocated() here.
# memory_reserved() includes memory reserved by the torch CUDA memory allocator that may or may not be
# re-used for future allocations. For now, we use memory_allocated() to be conservative.
# vram_reserved = torch.cuda.memory_reserved(self._execution_device)
vram_allocated = torch.cuda.memory_allocated(self._execution_device)
vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
vram_available_to_process = vram_free + vram_allocated
elif self._execution_device.type == "mps":
vram_reserved = torch.mps.driver_allocated_memory()
# TODO(ryand): Is it accurate that MPS shares memory with the CPU?
vram_free = psutil.virtual_memory().available
vram_available_to_process = vram_free + vram_reserved
else:
raise ValueError(f"Unsupported execution device: {self._execution_device.type}")
vram_total_available_to_cache = vram_available_to_process - working_mem_bytes
vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
return vram_cur_available_to_cache
def _get_vram_in_use(self) -> int:
"""Get the amount of VRAM currently in use by the cache."""
if self._execution_device.type == "cuda":
return torch.cuda.memory_allocated()
elif self._execution_device.type == "mps":
return torch.mps.current_allocated_memory()
else:
raise ValueError(f"Unsupported execution device type: {self._execution_device.type}")
# Alternative definition of VRAM in use:
# return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
def _get_ram_available(self) -> int:
"""Get the amount of RAM available for the cache to use, while keeping memory pressure under control."""
# If self._max_ram_cache_size_gb is set, then it overrides the default logic.
if self._max_ram_cache_size_gb is not None:
ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
return ram_total_available_to_cache - self._get_ram_in_use()
virtual_memory = psutil.virtual_memory()
ram_total = virtual_memory.total
ram_available = virtual_memory.available
ram_used = ram_total - ram_available
# The total size of all the models in the cache will often be larger than the amount of RAM reported by psutil
# (due to lazy-loading and OS RAM caching behaviour). We could just rely on the psutil values, but it feels
# like a bad idea to over-fill the model cache. So, for now, we'll try to keep the total size of models in the
# cache under the total amount of system RAM.
cache_ram_used = self._get_ram_in_use()
ram_used = max(cache_ram_used, ram_used)
# Aim to keep 10% of RAM free.
ram_available_based_on_memory_usage = int(ram_total * 0.9) - ram_used
# If we are running out of RAM, then there's an increased likelihood that we will run into this issue:
# https://github.com/invoke-ai/InvokeAI/issues/7513
# To keep things running smoothly, there's a minimum RAM cache size that we always allow (even if this means
# using swap).
min_ram_cache_size_bytes = 4 * GB
ram_available_based_on_min_cache_size = min_ram_cache_size_bytes - cache_ram_used
return max(ram_available_based_on_memory_usage, ram_available_based_on_min_cache_size)
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
def _get_cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
for cache_record in self._cached_models.values():
total += cache_record.size
return total
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str:
"""Helper function for preparing a VRAM state log string."""
model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
return (
f"model_total={model_total_bytes/MB:.0f} MB, "
+ f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
# + f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
+ f"vram_available={(vram_available/MB):.0f} MB, "
)
def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
if submodel_type:
return f"{model_key}:{submodel_type.value}"
else:
return model_key
def _offload_unlocked_models(self, vram_bytes_required: int, working_mem_bytes: Optional[int] = None) -> int:
"""Offload models from the execution_device until vram_bytes_required bytes are available, or all models are
offloaded. Of course, locked models are not offloaded.
def _offload_unlocked_models(self, size_required: int) -> None:
"""Offload models from the execution_device to make room for size_required.
Returns:
int: The number of bytes freed based on believed model sizes. The actual change in VRAM may be different.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
self._logger.debug(
f"Offloading unlocked models with goal of making room for {vram_bytes_required/MB:.2f}MB of VRAM."
)
vram_bytes_freed = 0
# TODO(ryand): Give more thought to the offloading policy used here.
cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
for cache_entry in cache_entries_increasing_size:
# We do not fully trust the count of bytes freed, so we check again on each iteration.
vram_available = self._get_vram_available(working_mem_bytes)
vram_bytes_to_free = vram_bytes_required - vram_available
if vram_bytes_to_free <= 0:
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if cache_entry.is_locked:
# TODO(ryand): In the future, we may want to partially unload locked models, but this requires careful
# handling of model patches (e.g. LoRA).
if not cache_entry.loaded:
continue
cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free)
if cache_entry_bytes_freed > 0:
if not cache_entry.locked:
self._move_model_to_device(cache_entry, self._storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self._logger.debug(
f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB."
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
)
vram_bytes_freed += cache_entry_bytes_freed
TorchDevice.empty_cache()
return vram_bytes_freed
def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
if self._logger.getEffectiveLevel() > logging.DEBUG:
# Short circuit if the logger is not set to debug. Some of the data lookups could take a non-negligible
# amount of time.
def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
log = f"{title}\n"
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
log_format = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
ram_in_use_bytes = self._get_ram_in_use()
ram_available_bytes = self._get_ram_available()
ram_size_bytes = ram_in_use_bytes + ram_available_bytes
ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
log += log_format.format(
f"Storage Device ({self._storage_device.type})",
ram_size_bytes / MB,
ram_in_use_bytes / MB,
ram_in_use_bytes_percent,
ram_available_bytes / MB,
ram_available_bytes_percent,
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self._storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(target_device, copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self._logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if self._execution_device.type != "cpu":
vram_in_use_bytes = self._get_vram_in_use()
vram_available_bytes = self._get_vram_available(None)
vram_size_bytes = vram_in_use_bytes + vram_available_bytes
vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
log += log_format.format(
f"Compute Device ({self._execution_device.type})",
vram_size_bytes / MB,
vram_in_use_bytes / MB,
vram_in_use_bytes_percent,
vram_available_bytes / MB,
vram_available_bytes_percent,
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
if torch.cuda.is_available():
log += " {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
log += " {:<30} {}\n".format("Total models:", len(self._cached_models))
if include_entry_details and len(self._cached_models) > 0:
log += " Models:\n"
log_format = (
" {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n"
)
for cache_record in self._cached_models.values():
total_bytes = cache_record.cached_model.total_bytes()
cur_vram_bytes = cache_record.cached_model.cur_vram_bytes()
cur_vram_bytes_percent = cur_vram_bytes / total_bytes if total_bytes > 0 else 0
cur_ram_bytes = total_bytes - cur_vram_bytes
cur_ram_bytes_percent = cur_ram_bytes / total_bytes if total_bytes > 0 else 0
log += log_format.format(
f"{cache_record.key} ({cache_record.cached_model.model.__class__.__name__}):",
total_bytes / MB,
cur_vram_bytes / MB,
cur_vram_bytes_percent,
cur_ram_bytes / MB,
cur_ram_bytes_percent,
cache_record.is_locked,
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self._logger.debug(
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GB):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
self._logger.debug(log)
def _print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self._get_cache_size() / GB)
def make_room(self, bytes_needed: int) -> None:
in_ram_models = 0
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
if hasattr(cache_record.model, "device"):
if cache_record.model.device == self._storage_device:
in_ram_models += 1
else:
in_vram_models += 1
if cache_record.locked:
locked_in_vram_models += 1
self._logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
self._logger.debug(f"Making room for {bytes_needed/MB:.2f}MB of RAM.")
self._log_cache_state(title="Before dropping models:")
bytes_needed = size
maximum_size = self._max_cache_size * GB # stored in GB, convert to bytes
current_size = self._get_cache_size()
ram_bytes_available = self._get_ram_available()
ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available)
if current_size + bytes_needed > maximum_size:
self._logger.debug(
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
ram_bytes_freed = 0
pos = 0
models_cleared = 0
while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self._logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
if not cache_entry.is_locked:
ram_bytes_freed += cache_entry.cached_model.total_bytes()
if not cache_entry.locked:
self._logger.debug(
f"Dropping {model_key} from RAM cache to free {(cache_entry.cached_model.total_bytes()/MB):.2f}MB."
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
self._delete_cache_entry(cache_entry)
del cache_entry
models_cleared += 1
else:
pos += 1
@@ -580,10 +435,8 @@ class ModelCache:
gc.collect()
TorchDevice.empty_cache()
self._logger.debug(f"Dropped {models_cleared} models to free {ram_bytes_freed/MB:.2f}MB of RAM.")
self._log_cache_state(title="After dropping models:")
self._logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
"""Delete cache_entry from the cache if it exists. No exception is thrown if it doesn't exist."""
self._cache_stack = [key for key in self._cache_stack if key != cache_entry.key]
self._cached_models.pop(cache_entry.key, None)
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]

View File

@@ -1,20 +0,0 @@
import itertools
import torch
def get_effective_device(model: torch.nn.Module) -> torch.device:
"""A utility to infer the 'effective' device of a model.
This utility handles the case where a model is partially loaded onto the GPU, so is safer than just calling:
`next(iter(model.parameters())).device`.
In the worst case, this utility has to check all model parameters, so if you already know the intended model device,
then it is better to avoid calling this function.
"""
# If all parameters are on the CPU, return the CPU device. Otherwise, return the first non-CPU device.
for p in itertools.chain(model.parameters(), model.buffers()):
if p.device.type != "cpu":
return p.device
return torch.device("cpu")

View File

@@ -14,7 +14,6 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.devices import TorchDevice
class ModelPatcher:
@@ -123,7 +122,7 @@ class ModelPatcher:
)
model_embeddings.weight.data[token_id] = embedding.to(
device=TorchDevice.choose_torch_device(), dtype=text_encoder.dtype
device=text_encoder.device, dtype=text_encoder.dtype
)
ti_tokens.append(token_id)

View File

@@ -2,7 +2,6 @@ from typing import Dict, Optional
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
@@ -51,7 +50,7 @@ class IA3Layer(LoRALayerBase):
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
return cast_to_device(orig_weight, weight.device) * weight
return orig_weight * weight
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device, dtype)

View File

@@ -12,7 +12,6 @@ from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
@@ -90,7 +89,7 @@ class T2IAdapterExt(ExtensionBase):
width=input_width,
height=input_height,
num_channels=model.config["in_channels"],
device=TorchDevice.choose_torch_device(),
device=model.device,
dtype=model.dtype,
resize_mode=self._resize_mode,
)

View File

@@ -1,12 +0,0 @@
import logging
from typing import Any, MutableMapping
# Issue with type hints related to LoggerAdapter: https://github.com/python/typeshed/issues/7855
class PrefixedLoggerAdapter(logging.LoggerAdapter): # type: ignore
def __init__(self, logger: logging.Logger, prefix: str):
super().__init__(logger, {})
self.prefix = prefix
def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, MutableMapping[str, Any]]:
return f"[{self.prefix}] {msg}", kwargs

View File

@@ -604,12 +604,7 @@
"hfForbidden": "Sie haben keinen Zugriff auf dieses HF-Modell",
"hfTokenInvalid": "Ungültiges oder fehlendes HF-Token",
"restoreDefaultSettings": "Klicken, um die Standardeinstellungen des Modells zu verwenden.",
"usingDefaultSettings": "Die Standardeinstellungen des Modells werden verwendet",
"hfTokenInvalidErrorMessage": "Ungültiges oder fehlendes HuggingFace-Token.",
"hfTokenUnableToVerify": "HF-Token kann nicht überprüft werden",
"hfTokenUnableToVerifyErrorMessage": "HuggingFace-Token kann nicht überprüft werden. Dies ist wahrscheinlich auf einen Netzwerkfehler zurückzuführen. Bitte versuchen Sie es später erneut.",
"hfTokenSaved": "HF-Token gespeichert",
"hfTokenRequired": "Sie versuchen, ein Modell herunterzuladen, für das ein gültiges HuggingFace-Token erforderlich ist."
"usingDefaultSettings": "Die Standardeinstellungen des Modells werden verwendet"
},
"parameters": {
"images": "Bilder",
@@ -660,23 +655,12 @@
"canvasIsCompositing": "Leinwand ist beschäftigt (wird zusammengesetzt)",
"canvasIsFiltering": "Leinwand ist beschäftigt (wird gefiltert)",
"canvasIsSelectingObject": "Leinwand ist beschäftigt (wird Objekt ausgewählt)",
"noPrompts": "Keine Eingabeaufforderungen generiert",
"noModelSelected": "Kein Modell ausgewählt"
"noPrompts": "Keine Eingabeaufforderungen generiert"
},
"seed": "Seed",
"patchmatchDownScaleSize": "Herunterskalieren",
"seamlessXAxis": "Nahtlose X Achse",
"seamlessYAxis": "Nahtlose Y Achse",
"coherenceEdgeSize": "Kantengröße",
"infillColorValue": "Füllfarbe",
"controlNetControlMode": "Kontrollmodus",
"cancel": {
"cancel": "Abbrechen"
},
"iterations": "Iterationen",
"guidance": "Führung",
"coherenceMode": "Modus",
"recallMetadata": "Metadaten abrufen"
"seamlessYAxis": "Nahtlose Y Achse"
},
"settings": {
"displayInProgress": "Zwischenbilder anzeigen",
@@ -1143,12 +1127,6 @@
"Nahtloses Kacheln eines Bildes entlang der horizontalen Achse."
],
"heading": "Nahtlose Kachelung X Achse"
},
"compositingCoherenceEdgeSize": {
"paragraphs": [
"Die Kantengröße des Kohärenzdurchlaufs."
],
"heading": "Kantengröße"
}
},
"invocationCache": {

View File

@@ -177,11 +177,7 @@
"none": "None",
"new": "New",
"generating": "Generating",
"warnings": "Warnings",
"start": "Start",
"count": "Count",
"step": "Step",
"values": "Values"
"warnings": "Warnings"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -854,14 +850,7 @@
"defaultVAE": "Default VAE"
},
"nodes": {
"noBatchGroup": "no group",
"generator": "Generator",
"generatedValues": "Generated Values",
"commitValues": "Commit Values",
"addValue": "Add Value",
"addNode": "Add Node",
"lockLinearView": "Lock Linear View",
"unlockLinearView": "Unlock Linear View",
"addNodeToolTip": "Add Node (Shift+A, Space)",
"addLinearView": "Add to Linear View",
"animatedEdges": "Animated Edges",
@@ -1035,21 +1024,11 @@
"addingImagesTo": "Adding images to",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
"missingInputForField": "missing input",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: missing input",
"missingNodeTemplate": "Missing node template",
"collectionEmpty": "empty collection",
"invalidBatchConfiguration": "Invalid batch configuration",
"batchNodeNotConnected": "Batch node not connected: {{label}}",
"collectionTooFewItems": "too few items, minimum {{minItems}}",
"collectionTooManyItems": "too many items, maximum {{maxItems}}",
"collectionStringTooLong": "too long, max {{maxLength}}",
"collectionStringTooShort": "too short, min {{minLength}}",
"collectionNumberGTMax": "{{value}} > {{maximum}} (inc max)",
"collectionNumberLTMin": "{{value}} < {{minimum}} (inc min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)",
"collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}",
"batchNodeCollectionSizeMismatch": "Collection size mismatch on Batch {{batchGroupId}}",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} empty collection",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: too few items, minimum {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: too many items, maximum {{maxItems}}",
"noModelSelected": "No model selected",
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
@@ -1206,7 +1185,6 @@
"modelAddedSimple": "Model Added to Queue",
"modelImportCanceled": "Model Import Canceled",
"outOfMemoryError": "Out of Memory Error",
"outOfMemoryErrorDescLocal": "Follow our <LinkComponent>Low VRAM guide</LinkComponent> to reduce OOMs.",
"outOfMemoryErrorDesc": "Your current generation settings exceed system capacity. Please adjust your settings and try again.",
"parameters": "Parameters",
"parameterSet": "Parameter Recalled",
@@ -1953,24 +1931,6 @@
"description": "Generates an edge map from the selected layer using the PiDiNet edge detection model.",
"scribble": "Scribble",
"quantize_edges": "Quantize Edges"
},
"img_blur": {
"label": "Blur Image",
"description": "Blurs the selected layer.",
"blur_type": "Blur Type",
"blur_radius": "Radius",
"gaussian_type": "Gaussian",
"box_type": "Box"
},
"img_noise": {
"label": "Noise Image",
"description": "Adds noise to the selected layer.",
"noise_type": "Noise Type",
"noise_amount": "Amount",
"gaussian_type": "Gaussian",
"salt_and_pepper_type": "Salt and Pepper",
"noise_color": "Colored Noise",
"size": "Noise Size"
}
},
"transform": {
@@ -2173,12 +2133,15 @@
"toGetStartedLocal": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"toGetStarted": "To get started, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"gettingStartedSeries": "Want more guidance? Check out our <LinkComponent>Getting Started Series</LinkComponent> for tips on unlocking the full potential of the Invoke Studio.",
"lowVRAMMode": "For best performance, follow our <LinkComponent>Low VRAM guide</LinkComponent>.",
"noModelsInstalled": "It looks like you don't have any models installed! You can <DownloadStarterModelsButton>download a starter model bundle</DownloadStarterModelsButton> or <ImportModelsButton>import models</ImportModelsButton>."
"downloadStarterModels": "Download Starter Models",
"importModels": "Import Models",
"noModelsInstalled": "It looks like you don't have any models installed"
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": ["Low-VRAM mode", "Dynamic memory management", "Faster model loading times", "Fewer memory errors"],
"items": [
"<StrongComponent>Flux Control Layers</StrongComponent>: New control models for edge detection and depth mapping are now supported for Flux dev models."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
"watchUiUpdatesOverview": "Watch UI Updates Overview"

View File

@@ -610,13 +610,9 @@
"hfTokenSaved": "Gettone HF salvato",
"hfForbidden": "Non hai accesso a questo modello HF",
"hfTokenLabel": "Gettone HuggingFace (richiesto per alcuni modelli)",
"hfForbiddenErrorMessage": "Consigliamo di visitare la pagina del repository. Il proprietario potrebbe richiedere l'accettazione dei termini per poter effettuare il download.",
"hfForbiddenErrorMessage": "Consigliamo di visitare la pagina del repository su HuggingFace.com. Il proprietario potrebbe richiedere l'accettazione dei termini per poter effettuare il download.",
"hfTokenInvalidErrorMessage2": "Aggiornalo in ",
"controlLora": "Controllo LoRA",
"urlUnauthorizedErrorMessage2": "Scopri come qui.",
"urlForbidden": "Non hai accesso a questo modello",
"urlForbiddenErrorMessage": "Potrebbe essere necessario richiedere l'autorizzazione al sito che distribuisce il modello.",
"urlUnauthorizedErrorMessage": "Potrebbe essere necessario configurare un gettone API per accedere a questo modello."
"controlLora": "Controllo LoRA"
},
"parameters": {
"images": "Immagini",

View File

@@ -1,14 +1,14 @@
{
"common": {
"hotkeysLabel": "Skróty klawiszowe",
"languagePickerLabel": "Język",
"languagePickerLabel": "Wybór języka",
"reportBugLabel": "Zgłoś błąd",
"settingsLabel": "Ustawienia",
"img2img": "Obraz na obraz",
"nodes": "Węzły",
"upload": "Prześlij",
"load": "Załaduj",
"statusDisconnected": "Odłączono",
"statusDisconnected": "Odłączono od serwera",
"githubLabel": "GitHub",
"discordLabel": "Discord",
"clipboard": "Schowek",
@@ -27,87 +27,13 @@
"back": "Do tyłu",
"auto": "Automatyczny",
"beta": "Beta",
"close": "Wyjdź",
"checkpoint": "Punkt kontrolny",
"controlNet": "ControlNet",
"details": "Detale",
"direction": "Kierunek",
"ipAdapter": "Adapter IP",
"dontAskMeAgain": "Nie pytaj ponownie",
"modelManager": "Menedżer modeli",
"blue": "Niebieski",
"orderBy": "Sortuj według",
"openInNewTab": "Otwórz w nowym oknie",
"somethingWentWrong": "Coś poszło nie tak",
"green": "Zielony",
"red": "Czerwony",
"imageFailedToLoad": "Nie można załadować obrazu",
"saveAs": "Zapisz jako",
"outputs": "Wyjścia",
"data": "Dane",
"localSystem": "System Lokalny",
"t2iAdapter": "Adapter T2I",
"selected": "Zaznaczone",
"warnings": "Ostrzeżenia",
"save": "Zapisz",
"created": "Stworzono",
"alpha": "Alfa",
"error": "Bład",
"editor": "Edytor",
"loading": "Ładuję",
"edit": "Edytuj",
"enabled": "Aktywny",
"communityLabel": "Społeczeństwo",
"linear": "Liniowy",
"installed": "Zainstalowany",
"dontShowMeThese": "Nie pokazuj mi tego",
"openInViewer": "Otwórz podgląd",
"safetensors": "Bezpieczniki",
"ok": "Ok",
"goTo": "Idź do",
"loadingImage": "wczytywanie zdjęcia",
"input": "Wejście",
"view": "Podgląd",
"learnMore": "Dowiedz się więcej",
"notInstalled": "Nie $t(common.installed)",
"loadingModel": "Wczytywanie modelu",
"postprocessing": "Przetwarzanie końcowe",
"random": "Losowo",
"disabled": "Wyłączony",
"generating": "Generowanie",
"simple": "Prosty",
"folder": "Katalog",
"format": "Format",
"updated": "Zaktualizowano",
"unknown": "nieznany",
"delete": "Usuń",
"template": "Szablon",
"txt2img": "Tekst na obraz",
"prevPage": "Poprzednia strona",
"file": "Plik",
"toResolve": "Do rozwiązania",
"nextPage": "Następna strona",
"unknownError": "Nieznany błąd",
"placeholderSelectAModel": "Wybierz model",
"new": "Nowy",
"none": "Żadne",
"reset": "Reset",
"on": "Włączony",
"aboutHeading": "Posiadaj swoją kreatywną moc"
"close": "Wyjdź"
},
"gallery": {
"galleryImageSize": "Rozmiar obrazów",
"gallerySettings": "Ustawienia galerii",
"autoSwitchNewImages": "Przełączaj na nowe obrazy",
"noImagesInGallery": "Brak obrazów w galerii",
"gallery": "Galeria",
"alwaysShowImageSizeBadge": "Zawsze pokazuj odznakę wielkości obrazu",
"assetsTab": "Pliki, które wrzuciłeś do użytku w twoich projektach.",
"currentlyInUse": "Ten obraz jest obecnie w użyciu przez następujące funkcje:",
"boardsSettings": "Ustawienia tablic",
"assets": "Aktywy",
"autoAssignBoardOnClick": "Automatycznie przypisz tablicę po kliknięciu",
"copy": "Kopiuj"
"noImagesInGallery": "Brak obrazów w galerii"
},
"parameters": {
"images": "L. obrazów",
@@ -157,14 +83,7 @@
"previousImage": "Poprzedni obraz",
"nextImage": "Następny obraz",
"menu": "Menu",
"mode": "Tryb",
"resetUI": "$t(accessibility.reset) UI",
"uploadImages": "Wgrywaj obrazy",
"about": "Informacje",
"toggleRightPanel": "Przełącz prawy panel (G)",
"toggleLeftPanel": "Przełącz lewy panel (G)",
"createIssue": "Stwórz problem",
"submitSupportTicket": "Wyślij bilet pomocy"
"mode": "Tryb"
},
"boards": {
"cancel": "Anuluj",
@@ -179,34 +98,7 @@
"downloadBoard": "Pobierz tablice",
"loading": "Ładowanie...",
"move": "Przenieś",
"noMatching": "Brak pasujących tablic",
"addBoard": "Dodaj tablicę",
"autoAddBoard": "Automatycznie dodaj tablicę",
"searchBoard": "Szukaj tablic.",
"unarchiveBoard": "Odarchiwizuj tablicę",
"selectedForAutoAdd": "Wybrany do automatycznego dodania",
"deleteBoard": "Usuń tablicę",
"clearSearch": "Usuń historię",
"hideBoards": "Ukryj tablice",
"viewBoards": "Zobacz tablice",
"addSharedBoard": "Dodaj udostępnioną tablicę",
"boards": "Tablice",
"addPrivateBoard": "Dodaj prywatną tablicę",
"movingImagesToBoard_one": "Przenoszenie {{count}} zdjęcia do tablicy:",
"movingImagesToBoard_few": "Przenoszenie {{count}} zdjęć do tablicy:",
"movingImagesToBoard_many": "Przenoszenie {{count}} zdjęć do tablicy:",
"shared": "Udostępnione tablice",
"topMessage": "Ta tablica zawiera obrazy wykorzystywane w następujących funkcjach:",
"deletedPrivateBoardsCannotbeRestored": "Usunięte tablice nie mogą być odzyskane. Wybierając \"Usuń tylko tablicę\" spowoduje że obrazy zostaną przeniesione do prywatnego nieskategoryzowanego stanu autora obrazu.",
"changeBoard": "Zmień tablicę",
"bottomMessage": "Usuwając tę tablicę oraz jej obrazów zresetują wszystkie funkcje które obecnie ich używają.",
"deleteBoardAndImages": "Usuń tablicę i zdjęcia",
"deleteBoardOnly": "Usuń tylko tablicę",
"deletedBoardsCannotbeRestored": "Usunięte tablice nie mogą być odzyskane. Wybierając \"Usuń tylko tablicę\" spowoduje że obrazy zostaną przeniesione do nieskategoryzowanego stanu.",
"archiveBoard": "Zarchiwizuj tablicę",
"archived": "Zarchiwizowano",
"myBoard": "Moja tablica",
"menuItemAutoAdd": "Automatycznie dodaj do tej tablicy"
"noMatching": "Brak pasujących tablic"
},
"accordions": {
"compositing": {
@@ -227,103 +119,5 @@
"control": {
"title": "Kontrola"
}
},
"hrf": {
"metadata": {
"enabled": "Włączono poprawkę wysokiej rozdzielczości",
"strength": "Moc poprawki wysokiej rozdzielczości",
"method": "Metoda High Resolution Fix"
},
"hrf": "Poprawka \"Wysoka rozdzielczość\"",
"enableHrf": "Włącz poprawkę wysokiej rozdzielczości"
},
"queue": {
"cancelTooltip": "Anuluj aktualną pozycję",
"resumeFailed": "Błąd z kontynuowaniem procesora",
"current": "Obecne",
"cancelBatchFailed": "Problem z anulacją masy",
"queueFront": "Dodaj do przodu kolejki",
"cancelBatch": "Anuluj serię",
"cancelFailed": "Problem z anulowaniem pozycji",
"pruneTooltip": "Wyczyść {{item_count}} skończonych pozycji",
"pruneSucceeded": "Wyczyszczono {{item_count}} zakończonych pozycji z kolejki",
"cancelBatchSucceeded": "Partia anulowana",
"clear": "Wyczyść",
"clearTooltip": "Anuluj i usuń wszystkie pozycje",
"clearSucceeded": "Kolejka wyczyszczona",
"cancelItem": "Anuluj pozycję",
"clearQueueAlertDialog2": "Czy na pewno chcesz wyczyścić kolejkę?",
"pauseFailed": "Problem z zapauzowaniem processora",
"clearFailed": "Problem z czyszczeniem kolejki",
"queueBack": "Dodaj do kolejki",
"queueEmpty": "Kolejka pusta",
"enqueueing": "Kolejkowanie partii",
"resumeTooltip": "Kontynuuj processor",
"resumeSucceeded": "Processor kontynuowany",
"pause": "Zapauzuj",
"pauseTooltip": "Zapauzuj processor",
"queue": "Kolejka",
"resume": "Kontynuuj",
"cancel": "Anuluj",
"cancelSucceeded": "Pozycja anulowana",
"prune": "Wyczyść",
"pauseSucceeded": "Processor zapauzowany",
"clearQueueAlertDialog": "Czyszczenie kolejki od razu anuluje wszystkie przetwarzane elementy and całkowicie czyści kolejkę. Oczekujące filtry zostaną anulowane.",
"pruneFailed": "Problem z wyczyszczeniem kolejki",
"batchQueued": "Masa w kolejce",
"openQueue": "Otwórz kolejkę",
"iterations_one": "Iteracja",
"iterations_few": "Iteracje",
"iterations_many": "Iteracje",
"graphQueued": "Wykres w kolejce",
"canvas": "Płótno",
"generation": "Generacja",
"status": "Status",
"total": "Suma",
"time": "Czas",
"front": "Przód",
"back": "tył",
"batchFailedToQueue": "Nie można zkolejkować masy",
"completedIn": "Ukończony w całości",
"other": "Inne",
"origin": "Pochodzenie",
"destination": "Miejsce docelowe",
"notReady": "Nie można zkolejkować",
"canceled": "Anulowano",
"in_progress": "W trakcie",
"gallery": "Galeria",
"session": "Sesja",
"pending": "W toku",
"completed": "Zakończono",
"item": "Pozycja",
"failed": "Niepowodzenie",
"batchFieldValues": "Masowe Wartości pól",
"graphFailedToQueue": "NIe udało się dodać tabeli do kolejki",
"workflows": "Przepływy pracy",
"next": "Następny",
"batchQueuedDesc_one": "Dodano {{count}} sesję do {{direction}} kolejki",
"batchQueuedDesc_few": "Dodano {{count}} sesje do {{direction}} kolejki",
"batchQueuedDesc_many": "Dodano {{count}} sesje do {{direction}} kolejki",
"batch": "Masa",
"upscaling": "Skalowanie w górę",
"generations_one": "Generacja",
"generations_few": "Generacje",
"generations_many": "Generacje",
"prompts_one": "Monit",
"prompts_few": "Monity",
"prompts_many": "Monity",
"batchSize": "Rozmiar masy"
},
"prompt": {
"compatibleEmbeddings": "Kompatybilne osadzenia",
"noMatchingTriggers": "Nie dopasowywanie spustów"
},
"invocationCache": {
"hits": "Uderzenia cache",
"enable": "Włącz",
"clear": "Wyczyść",
"disable": "Wyłącz",
"maxCacheSize": "Maksymalny rozmiar cache",
"cacheSize": "Rozmiar Cache"
}
}

View File

@@ -231,7 +231,7 @@
"resume": "Tiếp Tục",
"enqueueing": "Xếp Vào Hàng Hàng Loạt",
"prompts_other": "Lệnh",
"iterations_other": "Vòng Lặp",
"iterations_other": "Lặp Lại",
"total": "Tổng",
"pruneFailed": "Có Vấn Đề Khi Cắt Bớt Mục Khỏi Hàng",
"clearSucceeded": "Hàng Đã Được Dọn Sạch",
@@ -271,7 +271,7 @@
"queueFront": "Thêm Vào Đầu Hàng",
"resumeTooltip": "Tiếp Tục Bộ Xử Lý",
"clearFailed": "Có Vấn Đề Khi Dọn Dẹp Hàng",
"generations_other": "Ảnh Tạo Sinh",
"generations_other": "Máy Tạo Sinh",
"cancelBatch": "Huỷ Bỏ Lô",
"status": "Trạng Thái",
"pending": "Đang Chờ",
@@ -648,7 +648,7 @@
"defaultSettingsSaved": "Đã Lưu Thiết Lập Mặc Định",
"description": "Dòng Mô Tả",
"imageEncoderModelId": "ID Model Image Encoder",
"hfForbiddenErrorMessage": "Chúng tôi gợi ý vào các repository. Chủ sở hữu có thể yêu cầu chấp nhận điều khoản để tải xuống.",
"hfForbiddenErrorMessage": "Chúng tôi gợi ý vào trang repository trên HuggingFace.com. Chủ sở hữu có thể yêu cầu chấp nhận điều khoản để tải xuống.",
"hfTokenSaved": "Đã Lưu HF Token",
"learnMoreAboutSupportedModels": "Tìm hiểu thêm về những model được hỗ trợ",
"availableModels": "Model Có Sẵn",
@@ -731,7 +731,7 @@
"repo_id": "ID Repository",
"scanFolder": "Quét Thư Mục",
"scanFolderHelper": "Thư mục sẽ được quét để tìm model. Có thể sẽ mất nhiều thời gian với những thư mục lớn.",
"scanResults": "Kết Quả",
"scanResults": "Kết Quả Quét",
"t5Encoder": "T5 Encoder",
"mainModelTriggerPhrases": "Từ Ngữ Kích Hoạt Cho Model Chính",
"textualInversions": "Bộ Đảo Ngược Văn Bản",
@@ -739,12 +739,7 @@
"width": "Chiều Rộng",
"starterModelsInModelManager": "Model khởi đầu có thể tìm thấy ở Trình Quản Lý Model",
"clipLEmbed": "CLIP-L Embed",
"clipGEmbed": "CLIP-G Embed",
"controlLora": "LoRA Điều Khiển Được",
"urlUnauthorizedErrorMessage2": "Tìm hiểu thêm.",
"urlForbidden": "Bạn không có quyền truy cập vào model này",
"urlForbiddenErrorMessage": "Bạn có thể cần yêu cầu quyền truy cập từ trang web đang cung cấp model.",
"urlUnauthorizedErrorMessage": "Bạn có thể cần thiếp lập một token API để dùng được model này."
"clipGEmbed": "CLIP-G Embed"
},
"metadata": {
"guidance": "Hướng Dẫn",
@@ -1438,8 +1433,7 @@
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} tài nguyên trống",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: quá ít mục, tối thiểu {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: quá nhiều mục, tối đa {{maxItems}}",
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)",
"fluxModelMultipleControlLoRAs": "Chỉ có thể dùng 1 LoRA Điều Khiển Được"
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)"
},
"cfgScale": "Thang CFG",
"useSeed": "Dùng Hạt Giống",

View File

@@ -2,19 +2,10 @@ import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { ImageField } from 'features/nodes/types/common';
import {
isFloatFieldCollectionInputInstance,
isImageFieldCollectionInputInstance,
isIntegerFieldCollectionInputInstance,
isStringFieldCollectionInputInstance,
} from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { groupBy } from 'lodash-es';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Batch, BatchConfig } from 'services/api/types';
@@ -42,140 +33,28 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
const data: Batch['data'] = [];
const batchNodes = nodes.nodes.filter(isInvocationNode).filter(isBatchNode);
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
const addProductBatchDataCollectionItem = (
edges: InvocationNodeEdge[],
items?: ImageField[] | string[] | number[]
) => {
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
for (const edge of edges) {
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromImageBatch) {
if (!edge.targetHandle) {
break;
}
productBatchDataCollectionItems.push({
batchDataCollectionItem.push({
node_path: edge.target,
field_name: edge.targetHandle,
items,
items: images.value,
});
}
if (productBatchDataCollectionItems.length > 0) {
data.push(productBatchDataCollectionItems);
}
};
// Then, we will create a batch data collection item for each group
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
const addZippedBatchDataCollectionItem = (
edges: InvocationNodeEdge[],
items?: ImageField[] | string[] | number[]
) => {
for (const edge of edges) {
if (!edge.targetHandle) {
break;
}
zippedBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items,
});
}
};
// Grab image batch nodes for special handling
const imageBatchNodes = batchNodes.filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
// Satisfy TS
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromImageBatch, images.value);
} else {
addProductBatchDataCollectionItem(edgesFromImageBatch, images.value);
}
}
// Grab string batch nodes for special handling
const stringBatchNodes = batchNodes.filter((node) => node.data.type === 'string_batch');
for (const node of stringBatchNodes) {
// Satisfy TS
const strings = node.data.inputs['strings'];
if (!isStringFieldCollectionInputInstance(strings)) {
log.warn({ nodeId: node.id }, 'String batch strings field is not a string collection');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, strings.value);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, strings.value);
}
}
// Grab integer batch nodes for special handling
const integerBatchNodes = batchNodes.filter((node) => node.data.type === 'integer_batch');
for (const node of integerBatchNodes) {
// Satisfy TS
const integers = node.data.inputs['integers'];
if (!isIntegerFieldCollectionInputInstance(integers)) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is not an integer collection');
break;
}
if (!integers.value) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is empty');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
const resolvedValue = resolveNumberFieldCollectionValue(integers);
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}
// Grab float batch nodes for special handling
const floatBatchNodes = batchNodes.filter((node) => node.data.type === 'float_batch');
for (const node of floatBatchNodes) {
// Satisfy TS
const floats = node.data.inputs['floats'];
if (!isFloatFieldCollectionInputInstance(floats)) {
log.warn({ nodeId: node.id }, 'Float batch floats field is not a float collection');
break;
}
if (!floats.value) {
log.warn({ nodeId: node.id }, 'Float batch floats field is empty');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
const resolvedValue = resolveNumberFieldCollectionValue(floats);
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}
// Finally, if this batch data collection item has any items, add it to the data array
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
data.push(zippedBatchDataCollectionItems);
if (batchDataCollectionItem.length > 0) {
data.push(batchDataCollectionItem);
}
}

View File

@@ -166,10 +166,8 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: false,
immutableCheck: false,
// serializableCheck: import.meta.env.MODE === 'development',
// immutableCheck: import.meta.env.MODE === 'development',
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
})
.concat(api.middleware)
.concat(dynamicMiddlewares)

View File

@@ -57,7 +57,6 @@ export const CanvasMainPanelContent = memo(() => {
gap={2}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
<CanvasManagerProviderGate>
<CanvasToolbar />
@@ -71,7 +70,6 @@ export const CanvasMainPanelContent = memo(() => {
h="full"
bg={dynamicGrid ? 'base.850' : 'base.900'}
borderRadius="base"
overflow="hidden"
>
<InvokeCanvasComponent />
<CanvasManagerProviderGate>

View File

@@ -1,72 +0,0 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { BlurFilterConfig, BlurTypes } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS, isBlurTypes } from 'features/controlLayers/store/filters';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<BlurFilterConfig>;
const DEFAULTS = IMAGE_FILTERS.img_blur.buildDefaults();
export const FilterBlur = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleBlurTypeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isBlurTypes(v?.value)) {
return;
}
onChange({ ...config, blur_type: v.value });
},
[config, onChange]
);
const handleRadiusChange = useCallback(
(v: number) => {
onChange({ ...config, radius: v });
},
[config, onChange]
);
const options: { label: string; value: BlurTypes }[] = useMemo(
() => [
{ label: t('controlLayers.filter.img_blur.gaussian_type'), value: 'gaussian' },
{ label: t('controlLayers.filter.img_blur.box_type'), value: 'box' },
],
[t]
);
const value = useMemo(() => options.filter((o) => o.value === config.blur_type)[0], [options, config.blur_type]);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_type')}</FormLabel>
<Combobox value={value} options={options} onChange={handleBlurTypeChange} isSearchable={false} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_radius')}</FormLabel>
<CompositeSlider
value={config.radius}
defaultValue={DEFAULTS.radius}
onChange={handleRadiusChange}
min={1}
max={64}
step={0.1}
marks
/>
<CompositeNumberInput
value={config.radius}
defaultValue={DEFAULTS.radius}
onChange={handleRadiusChange}
min={1}
max={4096}
step={0.1}
/>
</FormControl>
</>
);
});
FilterBlur.displayName = 'FilterBlur';

View File

@@ -1,111 +0,0 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { NoiseFilterConfig, NoiseTypes } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS, isNoiseTypes } from 'features/controlLayers/store/filters';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<NoiseFilterConfig>;
const DEFAULTS = IMAGE_FILTERS.img_noise.buildDefaults();
export const FilterNoise = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleNoiseTypeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isNoiseTypes(v?.value)) {
return;
}
onChange({ ...config, noise_type: v.value });
},
[config, onChange]
);
const handleAmountChange = useCallback(
(v: number) => {
onChange({ ...config, amount: v });
},
[config, onChange]
);
const handleColorChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, noise_color: e.target.checked });
},
[config, onChange]
);
const handleSizeChange = useCallback(
(v: number) => {
onChange({ ...config, size: v });
},
[config, onChange]
);
const options: { label: string; value: NoiseTypes }[] = useMemo(
() => [
{ label: t('controlLayers.filter.img_noise.gaussian_type'), value: 'gaussian' },
{ label: t('controlLayers.filter.img_noise.salt_and_pepper_type'), value: 'salt_and_pepper' },
],
[t]
);
const value = useMemo(() => options.filter((o) => o.value === config.noise_type)[0], [options, config.noise_type]);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_type')}</FormLabel>
<Combobox value={value} options={options} onChange={handleNoiseTypeChange} isSearchable={false} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_amount')}</FormLabel>
<CompositeSlider
value={config.amount}
defaultValue={DEFAULTS.amount}
onChange={handleAmountChange}
min={0}
max={1}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.amount}
defaultValue={DEFAULTS.amount}
onChange={handleAmountChange}
min={0}
max={1}
step={0.01}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_noise.size')}</FormLabel>
<CompositeSlider
value={config.size}
defaultValue={DEFAULTS.size}
onChange={handleSizeChange}
min={1}
max={16}
step={1}
marks
/>
<CompositeNumberInput
value={config.size}
defaultValue={DEFAULTS.size}
onChange={handleSizeChange}
min={1}
max={256}
step={1}
/>
</FormControl>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_color')}</FormLabel>
<Switch defaultChecked={DEFAULTS.noise_color} isChecked={config.noise_color} onChange={handleColorChange} />
</FormControl>
</>
);
});
FilterNoise.displayName = 'Filternoise';

View File

@@ -1,5 +1,4 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { FilterBlur } from 'features/controlLayers/components/Filters/FilterBlur';
import { FilterCannyEdgeDetection } from 'features/controlLayers/components/Filters/FilterCannyEdgeDetection';
import { FilterColorMap } from 'features/controlLayers/components/Filters/FilterColorMap';
import { FilterContentShuffle } from 'features/controlLayers/components/Filters/FilterContentShuffle';
@@ -9,7 +8,6 @@ import { FilterHEDEdgeDetection } from 'features/controlLayers/components/Filter
import { FilterLineartEdgeDetection } from 'features/controlLayers/components/Filters/FilterLineartEdgeDetection';
import { FilterMediaPipeFaceDetection } from 'features/controlLayers/components/Filters/FilterMediaPipeFaceDetection';
import { FilterMLSDDetection } from 'features/controlLayers/components/Filters/FilterMLSDDetection';
import { FilterNoise } from 'features/controlLayers/components/Filters/FilterNoise';
import { FilterPiDiNetEdgeDetection } from 'features/controlLayers/components/Filters/FilterPiDiNetEdgeDetection';
import { FilterSpandrel } from 'features/controlLayers/components/Filters/FilterSpandrel';
import type { FilterConfig } from 'features/controlLayers/store/filters';
@@ -21,10 +19,6 @@ type Props = { filterConfig: FilterConfig; onChange: (filterConfig: FilterConfig
export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
const { t } = useTranslation();
if (filterConfig.type === 'img_blur') {
return <FilterBlur config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'canny_edge_detection') {
return <FilterCannyEdgeDetection config={filterConfig} onChange={onChange} />;
}
@@ -65,10 +59,6 @@ export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
return <FilterPiDiNetEdgeDetection config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'img_noise') {
return <FilterNoise config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'spandrel_filter') {
return <FilterSpandrel config={filterConfig} onChange={onChange} />;
}

View File

@@ -297,9 +297,10 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
const imageState = imageDTOToImageObject(filterResult.value);
this.$imageState.set(imageState);
// Stash the existing image module - we will destroy it after the new image is rendered to prevent a flash
// of an empty layer
const oldImageModule = this.imageModule;
// Destroy any existing masked image and create a new one
if (this.imageModule) {
this.imageModule.destroy();
}
this.imageModule = new CanvasObjectImage(imageState, this);
@@ -308,16 +309,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.konva.group.add(this.imageModule.konva.group);
// The filtered image have some transparency, so we need to hide the objects of the parent entity to prevent the
// two images from blending. We will show the objects again in the teardown method, which is always called after
// the filter finishes (applied or canceled).
this.parent.renderer.hideObjects();
if (oldImageModule) {
// Destroy the old image module now that the new one is rendered
oldImageModule.destroy();
}
// The porcessing is complete, set can set the last processed hash and isProcessing to false
this.$lastProcessedHash.set(hash);
@@ -433,8 +424,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
teardown = () => {
this.unsubscribe();
// Re-enable the objects of the parent entity
this.parent.renderer.showObjects();
this.konva.group.remove();
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.

View File

@@ -185,14 +185,6 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
return didRender;
};
hideObjects = () => {
this.konva.objectGroup.hide();
};
showObjects = () => {
this.konva.objectGroup.show();
};
adoptObjectRenderer = (renderer: AnyObjectRenderer) => {
this.renderers.set(renderer.id, renderer);
renderer.konva.group.moveTo(this.konva.objectGroup);

View File

@@ -10,7 +10,6 @@ import {
getKonvaNodeDebugAttrs,
getPrefixedId,
offsetCoord,
roundRect,
} from 'features/controlLayers/konva/util';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, RectWithRotation } from 'features/controlLayers/store/types';
@@ -774,7 +773,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const rect = this.getRelativeRect();
const rasterizeResult = await withResultAsync(() =>
this.parent.renderer.rasterize({
rect: roundRect(rect),
rect,
replaceObjects: true,
ignoreCache: true,
attrs: { opacity: 1, filters: [] },

View File

@@ -740,12 +740,3 @@ export const getColorAtCoordinate = (stage: Konva.Stage, coord: Coordinate): Rgb
return { r, g, b };
};
export const roundRect = (rect: Rect): Rect => {
return {
x: Math.round(rect.x),
y: Math.round(rect.y),
width: Math.round(rect.width),
height: Math.round(rect.height),
};
};

View File

@@ -95,28 +95,6 @@ const zSpandrelFilterConfig = z.object({
});
export type SpandrelFilterConfig = z.infer<typeof zSpandrelFilterConfig>;
const zBlurTypes = z.enum(['gaussian', 'box']);
export type BlurTypes = z.infer<typeof zBlurTypes>;
export const isBlurTypes = (v: unknown): v is BlurTypes => zBlurTypes.safeParse(v).success;
const zBlurFilterConfig = z.object({
type: z.literal('img_blur'),
blur_type: zBlurTypes,
radius: z.number().gte(0),
});
export type BlurFilterConfig = z.infer<typeof zBlurFilterConfig>;
const zNoiseTypes = z.enum(['gaussian', 'salt_and_pepper']);
export type NoiseTypes = z.infer<typeof zNoiseTypes>;
export const isNoiseTypes = (v: unknown): v is NoiseTypes => zNoiseTypes.safeParse(v).success;
const zNoiseFilterConfig = z.object({
type: z.literal('img_noise'),
noise_type: zNoiseTypes,
amount: z.number().gte(0).lte(1),
noise_color: z.boolean(),
size: z.number().int().gte(1),
});
export type NoiseFilterConfig = z.infer<typeof zNoiseFilterConfig>;
const zFilterConfig = z.discriminatedUnion('type', [
zCannyEdgeDetectionFilterConfig,
zColorMapFilterConfig,
@@ -131,8 +109,6 @@ const zFilterConfig = z.discriminatedUnion('type', [
zPiDiNetEdgeDetectionFilterConfig,
zDWOpenposeDetectionFilterConfig,
zSpandrelFilterConfig,
zBlurFilterConfig,
zNoiseFilterConfig,
]);
export type FilterConfig = z.infer<typeof zFilterConfig>;
@@ -150,8 +126,6 @@ const zFilterType = z.enum([
'pidi_edge_detection',
'dw_openpose_detection',
'spandrel_filter',
'img_blur',
'img_noise',
]);
export type FilterType = z.infer<typeof zFilterType>;
export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success;
@@ -455,62 +429,6 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
return true;
},
},
img_blur: {
type: 'img_blur',
buildDefaults: () => ({
type: 'img_blur',
blur_type: 'gaussian',
radius: 8,
}),
buildGraph: ({ image_name }, { blur_type, radius }) => {
const graph = new Graph(getPrefixedId('img_blur'));
const node = graph.addNode({
id: getPrefixedId('img_blur'),
type: 'img_blur',
image: { image_name },
blur_type: blur_type,
radius: radius,
});
return {
graph,
outputNodeId: node.id,
};
},
},
img_noise: {
type: 'img_noise',
buildDefaults: () => ({
type: 'img_noise',
noise_type: 'gaussian',
amount: 0.3,
noise_color: true,
size: 1,
}),
buildGraph: ({ image_name }, { noise_type, amount, noise_color, size }) => {
const graph = new Graph(getPrefixedId('img_noise'));
const node = graph.addNode({
id: getPrefixedId('img_noise'),
type: 'img_noise',
image: { image_name },
noise_type: noise_type,
amount: amount,
noise_color: noise_color,
size: size,
});
const rand = graph.addNode({
id: getPrefixedId('rand_int'),
use_cache: false,
type: 'rand_int',
low: 0,
high: 2147483647,
});
graph.addEdge(rand, 'value', node, 'seed');
return {
graph,
outputNodeId: node.id,
};
},
},
} as const;
/**

View File

@@ -1,5 +1,4 @@
import type {
BlurFilterConfig,
CannyEdgeDetectionFilterConfig,
ColorMapFilterConfig,
ContentShuffleFilterConfig,
@@ -13,7 +12,6 @@ import type {
LineartEdgeDetectionFilterConfig,
MediaPipeFaceDetectionFilterConfig,
MLSDDetectionFilterConfig,
NoiseFilterConfig,
NormalMapFilterConfig,
PiDiNetEdgeDetectionFilterConfig,
} from 'features/controlLayers/store/filters';
@@ -56,7 +54,6 @@ describe('Control Adapter Types', () => {
});
test('Processor Configs', () => {
// Types derived from OpenAPI
type _BlurFilterConfig = Required<Pick<Invocation<'img_blur'>, 'type' | 'radius' | 'blur_type'>>;
type _CannyEdgeDetectionFilterConfig = Required<
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
>;
@@ -74,9 +71,6 @@ describe('Control Adapter Types', () => {
type _MLSDDetectionFilterConfig = Required<
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
>;
type _NoiseFilterConfig = Required<
Pick<Invocation<'img_noise'>, 'type' | 'noise_type' | 'amount' | 'noise_color' | 'size'>
>;
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
type _DWOpenposeDetectionFilterConfig = Required<
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
@@ -87,7 +81,6 @@ describe('Control Adapter Types', () => {
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
// The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled.
assert<Equals<_BlurFilterConfig, BlurFilterConfig>>();
assert<Equals<_CannyEdgeDetectionFilterConfig, CannyEdgeDetectionFilterConfig>>();
assert<Equals<_ColorMapFilterConfig, ColorMapFilterConfig>>();
assert<Equals<_ContentShuffleFilterConfig, ContentShuffleFilterConfig>>();
@@ -97,7 +90,6 @@ describe('Control Adapter Types', () => {
assert<Equals<_LineartEdgeDetectionFilterConfig, LineartEdgeDetectionFilterConfig>>();
assert<Equals<_MediaPipeFaceDetectionFilterConfig, MediaPipeFaceDetectionFilterConfig>>();
assert<Equals<_MLSDDetectionFilterConfig, MLSDDetectionFilterConfig>>();
assert<Equals<_NoiseFilterConfig, NoiseFilterConfig>>();
assert<Equals<_NormalMapFilterConfig, NormalMapFilterConfig>>();
assert<Equals<_DWOpenposeDetectionFilterConfig, DWOpenposeDetectionFilterConfig>>();
assert<Equals<_PiDiNetEdgeDetectionFilterConfig, PiDiNetEdgeDetectionFilterConfig>>();

View File

@@ -11,7 +11,7 @@ import type { DndTargetState } from 'features/dnd/types';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { uploadImages } from 'services/api/endpoints/images';
import { useBoardName } from 'services/api/hooks/useBoardName';
@@ -72,10 +72,11 @@ export const FullscreenDropzone = memo(() => {
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const [dndState, setDndState] = useState<DndTargetState>('idle');
const uploadFilesSchema = useMemo(() => getFilesSchema(maxImageUploadCount), [maxImageUploadCount]);
const validateAndUploadFiles = useCallback(
(files: File[]) => {
const { getState } = getStore();
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
const parseResult = uploadFilesSchema.safeParse(files);
if (!parseResult.success) {
@@ -104,18 +105,7 @@ export const FullscreenDropzone = memo(() => {
uploadImages(uploadArgs);
},
[maxImageUploadCount, t]
);
const onPaste = useCallback(
(e: ClipboardEvent) => {
if (!e.clipboardData?.files) {
return;
}
const files = Array.from(e.clipboardData.files);
validateAndUploadFiles(files);
},
[validateAndUploadFiles]
[maxImageUploadCount, t, uploadFilesSchema]
);
useEffect(() => {
@@ -154,12 +144,24 @@ export const FullscreenDropzone = memo(() => {
}, [validateAndUploadFiles]);
useEffect(() => {
window.addEventListener('paste', onPaste);
const controller = new AbortController();
document.addEventListener(
'paste',
(e) => {
if (!e.clipboardData?.files) {
return;
}
const files = Array.from(e.clipboardData.files);
validateAndUploadFiles(files);
},
{ signal: controller.signal }
);
return () => {
window.removeEventListener('paste', onPaste);
controller.abort();
};
}, [onPaste]);
}, [validateAndUploadFiles]);
return (
<Box ref={ref} data-dnd-state={dndState} sx={sx}>

View File

@@ -1,4 +1,3 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
@@ -10,6 +9,7 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
import type { BoardId } from 'features/gallery/store/types';
import {
addImagesToBoard,
addImagesToNodeImageFieldCollectionAction,
createNewCanvasEntityFromImage,
removeImagesFromBoard,
replaceCanvasEntityObjectsWithImage,
@@ -19,14 +19,10 @@ import {
setRegionalGuidanceReferenceImage,
setUpscaleInitialImage,
} from 'features/imageActions/actions';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('dnd');
type RecordUnknown = Record<string | symbol, unknown>;
type DndData<
@@ -272,27 +268,15 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
}
const { fieldIdentifier } = targetData.payload;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}
const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
const imageDTOs: ImageDTO[] = [];
if (singleImageDndSource.typeGuard(sourceData)) {
newValue.push({ image_name: sourceData.payload.imageDTO.image_name });
imageDTOs.push(sourceData.payload.imageDTO);
} else {
newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name })));
imageDTOs.push(...sourceData.payload.imageDTOs);
}
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue }));
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
},
};
//#endregion

View File

@@ -46,7 +46,6 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
left={0}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
{hasImageToCompare && <CompareToolbar />}
{!hasImageToCompare && <ViewerToolbar closeButton={closeButton} />}

View File

@@ -1,5 +1,4 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Alert, AlertDescription, AlertIcon, Button, Divider, Flex, Link, Spinner, Text } from '@invoke-ai/ui-library';
import { Button, Divider, Flex, Spinner, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { InvokeLogoIcon } from 'common/components/InvokeLogoIcon';
@@ -8,10 +7,9 @@ import { $installModelsTab } from 'features/modelManagerV2/subpanels/InstallMode
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectIsLocal } from 'features/system/store/configSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiArrowSquareOutBold, PiImageBold } from 'react-icons/pi';
import { PiImageBold } from 'react-icons/pi';
import { useMainModels } from 'services/api/hooks/modelsByType';
export const NoContentForViewer = memo(() => {
@@ -20,105 +18,6 @@ export const NoContentForViewer = memo(() => {
const isLocal = useAppSelector(selectIsLocal);
const isEnabled = useFeatureStatus('starterModels');
const { t } = useTranslation();
const showStarterBundles = useMemo(() => {
return isEnabled && data && mainModels.length === 0;
}, [mainModels.length, data, isEnabled]);
if (hasImages === LOADING_SYMBOL) {
// Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered.
// If we show the logo while loading, there is an awkward layout shift where the invoke logo moves a bit. Less
// jarring to show a blank bg with a spinner - it will only be shown for a moment as we do the initial images
// fetching.
return <LoadingSpinner />;
}
if (hasImages) {
return <IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />;
}
return (
<Flex flexDir="column" gap={8} alignItems="center" textAlign="center" maxW="600px">
<InvokeLogoIcon w={32} h={32} />
<Flex flexDir="column" gap={4} alignItems="center" textAlign="center">
{isLocal ? <GetStartedLocal /> : <GetStartedCommercial />}
{showStarterBundles && <StarterBundlesCallout />}
<Divider />
<GettingStartedVideosCallout />
{isLocal && <LowVRAMAlert />}
</Flex>
</Flex>
);
});
NoContentForViewer.displayName = 'NoContentForViewer';
const LoadingSpinner = () => {
return (
<Flex position="relative" width="full" height="full" alignItems="center" justifyContent="center">
<Spinner label="Loading" color="grey" position="absolute" size="sm" width={8} height={8} right={4} bottom={4} />
</Flex>
);
};
export const ExternalLink = (props: ButtonProps & { href: string }) => {
return (
<Button
as={Link}
variant="unstyled"
isExternal
display="inline-flex"
alignItems="center"
rightIcon={<PiArrowSquareOutBold />}
color="base.50"
mt={-1}
{...props}
/>
);
};
const InlineButton = (props: PropsWithChildren<{ onClick: () => void }>) => {
return (
<Button variant="link" size="md" onClick={props.onClick} color="base.50">
{props.children}
</Button>
);
};
const StrongComponent = <Text as="span" color="base.50" fontSize="md" />;
const GetStartedLocal = () => {
return (
<Text fontSize="md" color="base.200">
<Trans i18nKey="newUserExperience.toGetStartedLocal" components={{ StrongComponent }} />
</Text>
);
};
const GetStartedCommercial = () => {
return (
<Text fontSize="md" color="base.200">
<Trans i18nKey="newUserExperience.toGetStarted" components={{ StrongComponent }} />
</Text>
);
};
const GettingStartedVideosCallout = () => {
return (
<Text fontSize="md" color="base.200">
<Trans
i18nKey="newUserExperience.gettingStartedSeries"
components={{
LinkComponent: (
<ExternalLink href="https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO" />
),
}}
/>
</Text>
);
};
const StarterBundlesCallout = () => {
const dispatch = useAppDispatch();
const handleClickDownloadStarterModels = useCallback(() => {
@@ -131,31 +30,89 @@ const StarterBundlesCallout = () => {
$installModelsTab.set(0);
}, [dispatch]);
return (
<Text fontSize="md" color="base.200">
<Trans
i18nKey="newUserExperience.noModelsInstalled"
components={{
DownloadStarterModelsButton: <InlineButton onClick={handleClickDownloadStarterModels} />,
ImportModelsButton: <InlineButton onClick={handleClickImportModels} />,
}}
/>
</Text>
);
};
const showStarterBundles = useMemo(() => {
return isEnabled && data && mainModels.length === 0;
}, [mainModels.length, data, isEnabled]);
if (hasImages === LOADING_SYMBOL) {
return (
// Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered.
// If we show the logo while loading, there is an awkward layout shift where the invoke logo moves a bit. Less
// jarring to show a blank bg with a spinner - it will only be shown for a moment as we do the initial images
// fetching.
<Flex position="relative" width="full" height="full" alignItems="center" justifyContent="center">
<Spinner label="Loading" color="grey" position="absolute" size="sm" width={8} height={8} right={4} bottom={4} />
</Flex>
);
}
if (hasImages) {
return <IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />;
}
const LowVRAMAlert = () => {
return (
<Alert status="warning" borderRadius="base" fontSize="md" shadow="md" w="fit-content">
<AlertIcon />
<AlertDescription>
<Trans
i18nKey="newUserExperience.lowVRAMMode"
components={{
LinkComponent: <ExternalLink href="https://invoke-ai.github.io/InvokeAI/features/low-vram/" />,
}}
/>
</AlertDescription>
</Alert>
<Flex flexDir="column" gap={4} alignItems="center" textAlign="center" maxW="600px">
<InvokeLogoIcon w={40} h={40} />
<Flex flexDir="column" gap={8} alignItems="center" textAlign="center">
<Text fontSize="md" color="base.200" pt={16}>
{isLocal ? (
<Trans
i18nKey="newUserExperience.toGetStartedLocal"
components={{
StrongComponent: <Text as="span" color="white" fontSize="md" fontWeight="semibold" />,
}}
/>
) : (
<Trans
i18nKey="newUserExperience.toGetStarted"
components={{
StrongComponent: <Text as="span" color="white" fontSize="md" fontWeight="semibold" />,
}}
/>
)}
</Text>
{showStarterBundles && (
<Flex flexDir="column" gap={2} alignItems="center">
<Text fontSize="md" color="base.200">
{t('newUserExperience.noModelsInstalled')}
</Text>
<Flex gap={3} alignItems="center">
<Button size="sm" onClick={handleClickDownloadStarterModels}>
{t('newUserExperience.downloadStarterModels')}
</Button>
<Text fontSize="sm" color="base.200">
{t('common.or')}
</Text>
<Button size="sm" onClick={handleClickImportModels}>
{t('newUserExperience.importModels')}
</Button>
</Flex>
</Flex>
)}
<Divider />
<Text fontSize="md" color="base.200">
<Trans
i18nKey="newUserExperience.gettingStartedSeries"
components={{
LinkComponent: (
<Text
as="a"
color="white"
fontSize="md"
fontWeight="semibold"
href="https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO"
target="_blank"
/>
),
}}
/>
</Text>
</Flex>
</Flex>
);
};
});
NoContentForViewer.displayName = 'NoContentForViewer';

View File

@@ -1,3 +1,4 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
@@ -30,15 +31,19 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { uniqBy } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const setGlobalReferenceImage = (arg: {
imageDTO: ImageDTO;
entityIdentifier: CanvasEntityIdentifier<'reference_image'>;
@@ -72,6 +77,54 @@ export const setNodeImageFieldImage = (arg: {
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
};
export const addImagesToNodeImageFieldCollectionAction = (arg: {
imageDTOs: ImageDTO[];
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
const uniqueImages = uniqBy(images, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};
export const removeImageFromNodeImageFieldCollectionAction = (arg: {
imageName: string;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageName, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
return;
}
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg;
dispatch(imageToCompareChanged(imageDTO));

View File

@@ -43,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
{fieldNames.connectionFields.map((fieldName, i) => (
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
<InvocationInputFieldCheck nodeId={nodeId} fieldName={fieldName}>
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
<InputField nodeId={nodeId} fieldName={fieldName} />
</InvocationInputFieldCheck>
</GridItem>
))}
@@ -59,7 +59,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
<InputField nodeId={nodeId} fieldName={fieldName} />
</InvocationInputFieldCheck>
))}
{fieldNames.missingFields.map((fieldName) => (
@@ -68,7 +68,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
<InputField nodeId={nodeId} fieldName={fieldName} />
</InvocationInputFieldCheck>
))}
</Flex>

View File

@@ -1,7 +1,7 @@
import { IconButton } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
import {
selectWorkflowSlice,
workflowExposedFieldAdded,
@@ -19,7 +19,7 @@ type Props = {
const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const field = useFieldInputInstance(nodeId, fieldName);
const value = useFieldValue(nodeId, fieldName);
const selectIsExposed = useMemo(
() =>
createSelector(selectWorkflowSlice, (workflow) => {
@@ -31,11 +31,8 @@ const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
const isExposed = useAppSelector(selectIsExposed);
const handleExposeField = useCallback(() => {
if (!field) {
return;
}
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, field }));
}, [dispatch, field, fieldName, nodeId]);
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, value }));
}, [dispatch, fieldName, nodeId, value]);
const handleUnexposeField = useCallback(() => {
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));

View File

@@ -14,10 +14,9 @@ import { InputFieldWrapper } from './InputFieldWrapper';
interface Props {
nodeId: string;
fieldName: string;
isLinearView: boolean;
}
const InputField = ({ nodeId, fieldName, isLinearView }: Props) => {
const InputField = ({ nodeId, fieldName }: Props) => {
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
const [isHovered, setIsHovered] = useState(false);
const isInvalid = useFieldIsInvalid(nodeId, fieldName);
@@ -70,12 +69,12 @@ const InputField = ({ nodeId, fieldName, isLinearView }: Props) => {
px={2}
>
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
<Flex gap={1} alignItems="center">
<Flex gap={1}>
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="inputs" isInvalid={isInvalid} withTooltip />
{isHovered && <FieldResetToDefaultValueButton nodeId={nodeId} fieldName={fieldName} />}
{isHovered && <FieldLinearViewToggle nodeId={nodeId} fieldName={fieldName} />}
</Flex>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={isLinearView} />
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</Flex>
</FormControl>

View File

@@ -1,7 +1,5 @@
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { NumberFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent';
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import {
@@ -23,8 +21,6 @@ import {
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
isEnumFieldInputTemplate,
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isFluxMainModelFieldInputInstance,
@@ -35,8 +31,6 @@ import {
isImageFieldCollectionInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
isIntegerFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
@@ -57,8 +51,6 @@ import {
isSDXLRefinerModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
@@ -99,285 +91,96 @@ import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = {
nodeId: string;
fieldName: string;
isLinearView: boolean;
};
const InputFieldRenderer = ({ nodeId, fieldName, isLinearView }: InputFieldProps) => {
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
return (
<StringFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return (
<StringFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) {
return (
<BooleanFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
return (
<NumberFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
return (
<NumberFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
return (
<NumberFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
return (
<NumberFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
if (
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
return (
<EnumFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isImageFieldCollectionInputInstance(fieldInstance) && isImageFieldCollectionInputTemplate(fieldTemplate)) {
return (
<ImageFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ImageFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
return (
<ImageFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) {
return (
<BoardFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <BoardFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) {
return (
<MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
return (
<ModelIdentifierFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return (
<RefinerModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) {
return (
<VAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
return (
<T5EncoderModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
return (
<CLIPEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
return (
<CLIPLEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
return (
<CLIPGEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isControlLoRAModelFieldInputInstance(fieldInstance) && isControlLoRAModelFieldInputTemplate(fieldTemplate)) {
return (
<ControlLoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return (
<FluxVAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return (
<LoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) {
return (
<ControlNetModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) {
return (
<IPAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
return (
<T2IAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (
@@ -389,64 +192,28 @@ const InputFieldRenderer = ({ nodeId, fieldName, isLinearView }: InputFieldProps
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return (
<ColorFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return (
<FluxMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return (
<SD3MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return (
<SDXLMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return (
<SchedulerFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (fieldTemplate) {

View File

@@ -1,6 +1,6 @@
import { Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectInvocationNode, selectNodesSlice } from 'features/nodes/store/selectors';
@@ -18,7 +18,7 @@ export const InvocationInputFieldCheck = memo(({ nodeId, fieldName, children }:
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodesSlice) => {
createSelector(selectNodesSlice, (nodesSlice) => {
const node = selectInvocationNode(nodesSlice, nodeId);
const instance = node.data.inputs[fieldName];
const template = templates[node.data.type];

View File

@@ -97,11 +97,7 @@ const LinearViewFieldInternal = ({ fieldIdentifier }: Props) => {
icon={<PiTrashSimpleBold />}
/>
</Flex>
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
isLinearView={true}
/>
<InputFieldRenderer nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName} />
</Flex>
</Flex>
<DndListDropIndicator dndState={dndListState} />

View File

@@ -26,7 +26,7 @@ const EnumFieldInputComponent = (props: FieldComponentProps<EnumFieldInputInstan
);
return (
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value} size="sm">
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value}>
{fieldTemplate.options.map((option) => (
<option key={option} value={option}>
{fieldTemplate.ui_choice_labels ? fieldTemplate.ui_choice_labels[option] : option}

View File

@@ -1,65 +0,0 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
import {
type FloatRangeStartStepCountGenerator,
getDefaultFloatRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
type FloatRangeGeneratorProps = {
state: FloatRangeStartStepCountGenerator;
onChange: (state: FloatRangeStartStepCountGenerator) => void;
};
export const FloatRangeGenerator = memo(({ state, onChange }: FloatRangeGeneratorProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
const onReset = useCallback(() => {
onChange(getDefaultFloatRangeStartStepCountGenerator());
}, [onChange]);
return (
<Flex gap={1} alignItems="flex-end" p={1}>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<IconButton
onClick={onReset}
aria-label={t('common.reset')}
icon={<PiArrowCounterClockwiseBold />}
variant="ghost"
/>
</Flex>
);
});
FloatRangeGenerator.displayName = 'FloatRangeGenerator';

View File

@@ -10,9 +10,9 @@ import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type { ImageField } from 'features/nodes/types/common';
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
@@ -61,12 +61,15 @@ export const ImageFieldCollectionInputComponent = memo(
);
const onRemoveImage = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldImageCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
(imageName: string) => {
removeImageFromNodeImageFieldCollectionAction({
imageName,
fieldIdentifier: { nodeId, fieldName: field.name },
dispatch: store.dispatch,
getState: store.getState,
});
},
[field.name, field.value, nodeId, store]
[field.name, nodeId, store.dispatch, store.getState]
);
return (
@@ -87,7 +90,7 @@ export const ImageFieldCollectionInputComponent = memo(
isError={isInvalid}
onUpload={onUpload}
fontSize={24}
variant="ghost"
variant="outline"
/>
)}
{field.value && field.value.length > 0 && (
@@ -99,9 +102,9 @@ export const ImageFieldCollectionInputComponent = memo(
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(4, 1fr)" gap={1}>
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<ImageGridItemContent value={value} index={index} onRemoveImage={onRemoveImage} />
{field.value.map(({ image_name }) => (
<GridItem key={image_name} position="relative" className="nodrag">
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
</GridItem>
))}
</Grid>
@@ -121,11 +124,11 @@ export const ImageFieldCollectionInputComponent = memo(
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';
const ImageGridItemContent = memo(
({ value, index, onRemoveImage }: { value: ImageField; index: number; onRemoveImage: (index: number) => void }) => {
const query = useGetImageDTOQuery(value.image_name);
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
const query = useGetImageDTOQuery(imageName);
const onClickRemove = useCallback(() => {
onRemoveImage(index);
}, [index, onRemoveImage]);
onRemoveImage(imageName);
}, [imageName, onRemoveImage]);
if (query.isLoading) {
return <IAINoContentFallbackWithSpinner />;

View File

@@ -1,320 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Button,
CompositeNumberInput,
Divider,
Flex,
FormControl,
FormLabel,
Grid,
GridItem,
IconButton,
Switch,
Text,
} from '@invoke-ai/ui-library';
import { NUMPY_RAND_MAX } from 'app/constants';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { FloatRangeGenerator } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import {
fieldNumberCollectionGeneratorCommitted,
fieldNumberCollectionGeneratorStateChanged,
fieldNumberCollectionGeneratorToggled,
fieldNumberCollectionLockLinearViewToggled,
fieldNumberCollectionValueChanged,
} from 'features/nodes/store/nodesSlice';
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import type {
FloatRangeStartStepCountGenerator,
IntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLockSimpleFill, PiLockSimpleOpenFill, PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;
export const NumberFieldCollectionInputComponent = memo(
(
props:
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
) => {
const { nodeId, field, fieldTemplate, isLinearView } = props;
const store = useAppStore();
const { t } = useTranslation();
const isInvalid = useFieldIsInvalid(nodeId, field.name);
const isIntegerField = useMemo(() => fieldTemplate.type.name === 'IntegerField', [fieldTemplate.type]);
const onRemoveNumber = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onChangeNumber = useCallback(
(index: number, value: number) => {
const newValue = field.value ? [...field.value] : [];
newValue[index] = value;
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onAddNumber = useCallback(() => {
const newValue = field.value ? [...field.value, 0] : [0];
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);
const min = useMemo(() => {
let min = -NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.minimum)) {
min = fieldTemplate.minimum;
}
if (!isNil(fieldTemplate.exclusiveMinimum)) {
min = fieldTemplate.exclusiveMinimum + 0.01;
}
return min;
}, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);
const max = useMemo(() => {
let max = NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.maximum)) {
max = fieldTemplate.maximum;
}
if (!isNil(fieldTemplate.exclusiveMaximum)) {
max = fieldTemplate.exclusiveMaximum - 0.01;
}
return max;
}, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);
const step = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.1;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
const fineStep = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.01;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
const toggleGenerator = useCallback(() => {
store.dispatch(fieldNumberCollectionGeneratorToggled({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);
const onChangeGenerator = useCallback(
(generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator) => {
store.dispatch(fieldNumberCollectionGeneratorStateChanged({ nodeId, fieldName: field.name, generatorState }));
},
[field.name, nodeId, store]
);
const onCommitGenerator = useCallback(() => {
store.dispatch(fieldNumberCollectionGeneratorCommitted({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);
const onToggleLockLinearView = useCallback(() => {
store.dispatch(fieldNumberCollectionLockLinearViewToggled({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);
const valuesAsString = useMemo(() => {
const resolvedValue = resolveNumberFieldCollectionValue(field);
return resolvedValue ? resolvedValue.map((val) => round(val, 2)).join(', ') : '';
}, [field]);
const isLockedOnLinearView = !(field.lockLinearView && isLinearView);
return (
<Flex
className="nodrag"
position="relative"
w="full"
h="auto"
maxH={64}
alignItems="stretch"
justifyContent="center"
p={1}
sx={sx}
data-error={isInvalid}
borderRadius="base"
flexDir="column"
gap={1}
>
<Flex w="full" gap={2}>
{!field.generator && (
<Button onClick={onAddNumber} variant="ghost" flexGrow={1} size="sm">
{t('nodes.addValue')}
</Button>
)}
{field.generator && isLockedOnLinearView && (
<Button
tooltip={
<Flex p={1} flexDir="column">
<Text fontWeight="semibold">{t('nodes.generatedValues')}:</Text>
<Text fontFamily="monospace">{valuesAsString}</Text>
</Flex>
}
onClick={onCommitGenerator}
variant="ghost"
flexGrow={1}
size="sm"
>
{t('nodes.commitValues')}
</Button>
)}
{isLockedOnLinearView && (
<FormControl w="min-content" pe={isLinearView ? 2 : undefined}>
<FormLabel m={0}>{t('nodes.generator')}</FormLabel>
<Switch onChange={toggleGenerator} isChecked={Boolean(field.generator)} size="sm" />
</FormControl>
)}
{!isLinearView && (
<IconButton
onClick={onToggleLockLinearView}
tooltip={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
aria-label={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
icon={field.lockLinearView ? <PiLockSimpleFill /> : <PiLockSimpleOpenFill />}
variant="ghost"
size="sm"
/>
)}
</Flex>
{!field.generator && field.value && field.value.length > 0 && (
<>
{!(field.lockLinearView && isLinearView) && <Divider />}
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid gap={1} gridTemplateColumns="auto 1fr auto" alignItems="center">
{field.value.map((value, index) => (
<NumberListItemContent
key={index}
value={value}
index={index}
min={min}
max={max}
step={step}
fineStep={fineStep}
isIntegerField={isIntegerField}
onRemoveNumber={onRemoveNumber}
onChangeNumber={onChangeNumber}
/>
))}
</Grid>
</OverlayScrollbarsComponent>
</>
)}
{field.generator && field.generator.type === 'float-range-generator-start-step-count' && (
<>
{!(field.lockLinearView && isLinearView) && <Divider />}
<FloatRangeGenerator state={field.generator} onChange={onChangeGenerator} />
</>
)}
</Flex>
);
}
);
NumberFieldCollectionInputComponent.displayName = 'NumberFieldCollectionInputComponent';
type NumberListItemContentProps = {
value: number;
index: number;
isIntegerField: boolean;
min: number;
max: number;
step: number;
fineStep: number;
onRemoveNumber: (index: number) => void;
onChangeNumber: (index: number, value: number) => void;
};
const NumberListItemContent = memo(
({
value,
index,
isIntegerField,
min,
max,
step,
fineStep,
onRemoveNumber,
onChangeNumber,
}: NumberListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveNumber(index);
}, [index, onRemoveNumber]);
const onChange = useCallback(
(v: number) => {
onChangeNumber(index, isIntegerField ? Math.floor(Number(v)) : Number(v));
},
[index, isIntegerField, onChangeNumber]
);
return (
<>
<GridItem>
<FormLabel ps={1} m={0}>
{index + 1}.
</FormLabel>
</GridItem>
<GridItem>
<CompositeNumberInput
onChange={onChange}
value={value}
min={min}
max={max}
step={step}
fineStep={fineStep}
className="nodrag"
flexGrow={1}
/>
</GridItem>
<GridItem>
<IconButton
tabIndex={-1}
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.delete')}
/>
</GridItem>
</>
);
}
);
NumberListItemContent.displayName = 'NumberListItemContent';

View File

@@ -1,152 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Grid, GridItem, IconButton, Input } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldStringCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type {
StringFieldCollectionInputInstance,
StringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold, PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;
export const StringFieldCollectionInputComponent = memo(
(props: FieldComponentProps<StringFieldCollectionInputInstance, StringFieldCollectionInputTemplate>) => {
const { nodeId, field } = props;
const store = useAppStore();
const isInvalid = useFieldIsInvalid(nodeId, field.name);
const onRemoveString = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onChangeString = useCallback(
(index: number, value: string) => {
const newValue = field.value ? [...field.value] : [];
newValue[index] = value;
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onAddString = useCallback(() => {
const newValue = field.value ? [...field.value, ''] : [''];
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);
return (
<Flex
className="nodrag"
position="relative"
w="full"
h="full"
maxH={64}
alignItems="stretch"
justifyContent="center"
>
{(!field.value || field.value.length === 0) && (
<Box w="full" sx={sx} data-error={isInvalid} borderRadius="base">
<IconButton
w="full"
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
</Box>
)}
{field.value && field.value.length > 0 && (
<Box w="full" h="auto" p={1} sx={sx} data-error={isInvalid} borderRadius="base">
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(1, 1fr)" gap={1}>
<IconButton
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<StringListItemContent
value={value}
index={index}
onRemoveString={onRemoveString}
onChangeString={onChangeString}
/>
</GridItem>
))}
</Grid>
</OverlayScrollbarsComponent>
</Box>
)}
</Flex>
);
}
);
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
type StringListItemContentProps = {
value: string;
index: number;
onRemoveString: (index: number) => void;
onChangeString: (index: number, value: string) => void;
};
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveString(index);
}, [index, onRemoveString]);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChangeString(index, e.target.value);
},
[index, onChangeString]
);
return (
<Flex alignItems="center" gap={1}>
<Input size="xs" resize="none" value={value} onChange={onChange} />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.remove')}
tooltip={t('common.remove')}
/>
</Flex>
);
});
StringListItemContent.displayName = 'StringListItemContent';

View File

@@ -4,5 +4,4 @@ export type FieldComponentProps<V extends FieldInputInstance, T extends FieldInp
nodeId: string;
field: V;
fieldTemplate: T;
isLinearView: boolean;
};

View File

@@ -1,14 +1,12 @@
import type { SystemStyleObject, TextProps } from '@invoke-ai/ui-library';
import { Box, Editable, EditableInput, Flex, Text, useEditableControls } from '@invoke-ai/ui-library';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Editable, EditableInput, EditablePreview, Flex, useEditableControls } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import type { MouseEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
@@ -19,8 +17,6 @@ type Props = {
const NodeTitle = ({ nodeId, title }: Props) => {
const dispatch = useAppDispatch();
const label = useNodeLabel(nodeId);
const batchGroupId = useBatchGroupId(nodeId);
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
const templateTitle = useNodeTemplateTitle(nodeId);
const { t } = useTranslation();
@@ -33,16 +29,6 @@ const NodeTitle = ({ nodeId, title }: Props) => {
[dispatch, nodeId, title, templateTitle, label, t]
);
const localTitleWithBatchGroupId = useMemo(() => {
if (!batchGroupId) {
return localTitle;
}
if (batchGroupId === 'None') {
return `${localTitle} (${t('nodes.noBatchGroup')})`;
}
return `${localTitle} (${batchGroupId})`;
}, [batchGroupId, localTitle, t]);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
@@ -64,16 +50,7 @@ const NodeTitle = ({ nodeId, title }: Props) => {
w="full"
h="full"
>
<Preview
fontSize="sm"
p={0}
w="full"
noOfLines={1}
color={batchGroupColorToken}
fontWeight={batchGroupId ? 'semibold' : undefined}
>
{localTitleWithBatchGroupId}
</Preview>
<EditablePreview fontSize="sm" p={0} w="full" noOfLines={1} />
<EditableInput className="nodrag" fontSize="sm" sx={editableInputStyles} />
<EditableControls />
</Editable>
@@ -83,16 +60,6 @@ const NodeTitle = ({ nodeId, title }: Props) => {
export default memo(NodeTitle);
const Preview = (props: TextProps) => {
const { isEditing } = useEditableControls();
if (isEditing) {
return null;
}
return <Text {...props} />;
};
function EditableControls() {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback(

View File

@@ -5,7 +5,7 @@ import NodeOpacitySlider from './NodeOpacitySlider';
import ViewportControls from './ViewportControls';
const BottomLeftPanel = () => (
<Flex gap={2} position="absolute" bottom={2} insetInlineStart={2}>
<Flex gap={2} position="absolute" bottom={0} insetInlineStart={0}>
<ViewportControls />
<NodeOpacitySlider />
</Flex>

View File

@@ -20,7 +20,7 @@ const MinimapPanel = () => {
const shouldShowMinimapPanel = useAppSelector(selectShouldShowMinimapPanel);
return (
<Flex gap={2} position="absolute" bottom={2} insetInlineEnd={2}>
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
{shouldShowMinimapPanel && (
<ChakraMiniMap
pannable

View File

@@ -12,7 +12,7 @@ import { memo } from 'react';
const TopCenterPanel = () => {
const name = useAppSelector(selectWorkflowName);
return (
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
<Flex gap={2} top={0} left={0} right={0} position="absolute" alignItems="flex-start" pointerEvents="none">
<Flex gap="2">
<AddNodeButton />
<UpdateNodesButton />

View File

@@ -46,7 +46,7 @@ const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
</Flex>
</Tooltip>
</Flex>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={true} />
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</Flex>
);
};

View File

@@ -1,22 +0,0 @@
import { useMemo } from 'react';
export const useBatchGroupColorToken = (batchGroupId?: string) => {
const batchGroupColorToken = useMemo(() => {
switch (batchGroupId) {
case 'Group 1':
return 'invokeGreen.300';
case 'Group 2':
return 'invokeBlue.300';
case 'Group 3':
return 'invokePurple.200';
case 'Group 4':
return 'invokeRed.300';
case 'Group 5':
return 'invokeYellow.300';
default:
return undefined;
}
}, [batchGroupId]);
return batchGroupColorToken;
};

View File

@@ -1,19 +0,0 @@
import { useNode } from 'features/nodes/hooks/useNode';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
export const useBatchGroupId = (nodeId: string) => {
const node = useNode(nodeId);
const batchGroupId = useMemo(() => {
if (!isInvocationNode(node)) {
return;
}
if (!isBatchNode(node)) {
return;
}
return node.data.inputs['batch_group_id']?.value as string;
}, [node]);
return batchGroupId;
};

View File

@@ -3,21 +3,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import {
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import {
validateImageFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
} from 'features/nodes/types/fieldValidators';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
@@ -49,27 +35,13 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
}
// Else special handling for individual field types
if (isImageFieldCollectionInputInstance(field) && isImageFieldCollectionInputTemplate(template)) {
if (validateImageFieldCollectionValue(field.value, template).length > 0) {
// Image collections may have min or max item counts
if (template.minItems !== undefined && field.value.length < template.minItems) {
return true;
}
}
if (isStringFieldCollectionInputInstance(field) && isStringFieldCollectionInputTemplate(template)) {
if (validateStringFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field, template).length > 0) {
return true;
}
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field, template).length > 0) {
if (template.maxItems !== undefined && field.value.length > template.maxItems) {
return true;
}
}

View File

@@ -1,9 +1,8 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isEqual } from 'lodash-es';
import { useCallback, useMemo } from 'react';
@@ -11,38 +10,19 @@ export const useFieldOriginalValue = (nodeId: string, fieldName: string) => {
const dispatch = useAppDispatch();
const selectOriginalExposedFieldValues = useMemo(
() =>
createMemoizedSelector(selectWorkflowSlice, (workflow) =>
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)
createSelector(
selectWorkflowSlice,
(workflow) =>
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)?.value
),
[nodeId, fieldName]
);
const exposedField = useAppSelector(selectOriginalExposedFieldValues);
const field = useFieldInputInstance(nodeId, fieldName);
const isValueChanged = useMemo(() => {
if (!field) {
// Field is not found, so it is not changed
return false;
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputInstance(exposedField?.field)) {
return !isEqual(field.generator, exposedField.field.generator);
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputInstance(exposedField?.field)) {
return !isEqual(field.generator, exposedField.field.generator);
}
return !isEqual(field.value, exposedField?.field.value);
}, [field, exposedField]);
const originalValue = useAppSelector(selectOriginalExposedFieldValues);
const value = useFieldValue(nodeId, fieldName);
const isValueChanged = useMemo(() => !isEqual(value, originalValue), [value, originalValue]);
const onReset = useCallback(() => {
if (!exposedField) {
return;
}
const { value } = exposedField.field;
const generator =
isIntegerFieldCollectionInputInstance(exposedField.field) ||
isFloatFieldCollectionInputInstance(exposedField.field)
? exposedField.field.generator
: undefined;
dispatch(fieldValueReset({ nodeId, fieldName, value, generator }));
}, [dispatch, fieldName, nodeId, exposedField]);
dispatch(fieldValueReset({ nodeId, fieldName, value: originalValue }));
}, [dispatch, fieldName, nodeId, originalValue]);
return { originalValue: exposedField, isValueChanged, onReset };
return { originalValue, isValueChanged, onReset };
};

View File

@@ -19,7 +19,6 @@ import type {
FluxVAEModelFieldValue,
ImageFieldCollectionValue,
ImageFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IPAdapterModelFieldValue,
LoRAModelFieldValue,
@@ -29,15 +28,12 @@ import type {
SDXLRefinerModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldCollectionValue,
StringFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
} from 'features/nodes/types/field';
import {
isFloatFieldCollectionInputInstance,
isIntegerFieldCollectionInputInstance,
zBoardFieldValue,
zBooleanFieldValue,
zCLIPEmbedModelFieldValue,
@@ -47,12 +43,10 @@ import {
zControlLoRAModelFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
zFloatFieldCollectionValue,
zFloatFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
zImageFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIPAdapterModelFieldValue,
zLoRAModelFieldValue,
@@ -62,22 +56,11 @@ import {
zSDXLRefinerModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldCollectionValue,
zStringFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type {
FloatRangeStartStepCountGenerator,
IntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import {
floatRangeStartStepCountGenerator,
getDefaultFloatRangeStartStepCountGenerator,
getDefaultIntegerRangeStartStepCountGenerator,
integerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom, computed } from 'nanostores';
@@ -95,22 +78,11 @@ const initialNodesState: NodesState = {
edges: [],
};
type FieldValueAction<T extends FieldValue, U = unknown> = PayloadAction<
{
nodeId: string;
fieldName: string;
value: T;
} & U
>;
const selectField = (state: NodesState, nodeId: string, fieldName: string) => {
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
return node.data?.inputs[fieldName];
};
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
fieldName: string;
value: T;
}>;
const fieldValueReducer = <T extends FieldValue>(
state: NodesState,
@@ -118,24 +90,17 @@ const fieldValueReducer = <T extends FieldValue>(
schema: z.ZodTypeAny
) => {
const { nodeId, fieldName, value } = action.payload;
const field = selectField(state, nodeId, fieldName);
const result = schema.safeParse(value);
if (!field || !result.success) {
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
field.value = result.data;
// Special handling if the field value is being reset
if (result.data === undefined) {
if (isFloatFieldCollectionInputInstance(field)) {
if (field.lockLinearView && field.generator) {
field.generator = getDefaultFloatRangeStartStepCountGenerator();
}
} else if (isIntegerFieldCollectionInputInstance(field)) {
if (field.lockLinearView && field.generator) {
field.generator = getDefaultIntegerRangeStartStepCountGenerator();
}
}
const input = node.data?.inputs[fieldName];
const result = schema.safeParse(value);
if (!input || nodeIndex < 0 || !result.success) {
return;
}
input.value = result.data;
};
export const nodesSlice = createSlice({
@@ -340,123 +305,15 @@ export const nodesSlice = createSlice({
}
node.data.notes = notes;
},
fieldValueReset: (
state,
action: FieldValueAction<
StatefulFieldValue,
{ generator?: IntegerRangeStartStepCountGenerator | FloatRangeStartStepCountGenerator }
>
) => {
const { nodeId, fieldName, value, generator } = action.payload;
const field = selectField(state, nodeId, fieldName);
const result = zStatefulFieldValue.safeParse(value);
if (!field || !result.success) {
return;
}
field.value = result.data;
if (isFloatFieldCollectionInputInstance(field) && generator?.type === 'float-range-generator-start-step-count') {
field.generator = generator;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
generator?.type === 'integer-range-generator-start-step-count'
) {
field.generator = generator;
}
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
fieldValueReducer(state, action, zStatefulFieldValue);
},
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
fieldValueReducer(state, action, zStringFieldValue);
},
fieldStringCollectionValueChanged: (state, action: FieldValueAction<StringFieldCollectionValue>) => {
fieldValueReducer(state, action, zStringFieldCollectionValue);
},
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
},
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
},
fieldNumberCollectionGeneratorToggled: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (isFloatFieldCollectionInputInstance(field)) {
field.generator = field.generator ? undefined : getDefaultFloatRangeStartStepCountGenerator();
} else if (isIntegerFieldCollectionInputInstance(field)) {
field.generator = field.generator ? undefined : getDefaultIntegerRangeStartStepCountGenerator();
} else {
// This should never happen
}
},
fieldNumberCollectionGeneratorStateChanged: (
state,
action: PayloadAction<{
nodeId: string;
fieldName: string;
generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator;
}>
) => {
const { nodeId, fieldName, generatorState } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (
isFloatFieldCollectionInputInstance(field) &&
generatorState.type === 'float-range-generator-start-step-count'
) {
field.generator = generatorState;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
generatorState.type === 'integer-range-generator-start-step-count'
) {
field.generator = generatorState;
} else {
// This should never happen
}
},
fieldNumberCollectionGeneratorCommitted: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (
isFloatFieldCollectionInputInstance(field) &&
field.generator &&
field.generator.type === 'float-range-generator-start-step-count'
) {
field.value = floatRangeStartStepCountGenerator(field.generator);
field.generator = undefined;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
field.generator &&
field.generator.type === 'integer-range-generator-start-step-count'
) {
field.value = integerRangeStartStepCountGenerator(field.generator);
field.generator = undefined;
} else {
// This should never happen
}
},
fieldNumberCollectionLockLinearViewToggled: (
state,
action: PayloadAction<{ nodeId: string; fieldName: string }>
) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (!isFloatFieldCollectionInputInstance(field) && !isIntegerFieldCollectionInputInstance(field)) {
return;
}
field.lockLinearView = !field.lockLinearView;
},
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
fieldValueReducer(state, action, zBooleanFieldValue);
},
@@ -578,15 +435,9 @@ export const {
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
fieldNumberCollectionGeneratorToggled,
fieldNumberCollectionGeneratorStateChanged,
fieldNumberCollectionGeneratorCommitted,
fieldNumberCollectionLockLinearViewToggled,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
@@ -695,11 +546,9 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldLoRAModelValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,

View File

@@ -1,8 +1,8 @@
import type {
FieldIdentifier,
FieldInputInstance,
FieldInputTemplate,
FieldOutputTemplate,
StatefulFieldValue,
} from 'features/nodes/types/field';
import type {
AnyNode,
@@ -31,15 +31,15 @@ export type NodesState = {
};
export type WorkflowMode = 'edit' | 'view';
export type FieldIdentifierWithInstance = FieldIdentifier & {
field: FieldInputInstance;
export type FieldIdentifierWithValue = FieldIdentifier & {
value: StatefulFieldValue;
};
export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
_version: 2;
_version: 1;
isTouched: boolean;
mode: WorkflowMode;
originalExposedFieldValues: FieldIdentifierWithInstance[];
originalExposedFieldValues: FieldIdentifierWithValue[];
searchTerm: string;
orderBy?: WorkflowRecordOrderBy;
orderDirection: SQLiteDirection;

View File

@@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
import type {
FieldIdentifierWithInstance,
FieldIdentifierWithValue,
WorkflowMode,
WorkflowsState as WorkflowState,
} from 'features/nodes/store/types';
@@ -31,7 +31,7 @@ const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
};
const initialWorkflowState: WorkflowState = {
_version: 2,
_version: 1,
isTouched: false,
mode: 'view',
originalExposedFieldValues: [],
@@ -62,7 +62,7 @@ export const workflowSlice = createSlice({
const { id, isOpen } = action.payload;
state.categorySections[id] = isOpen;
},
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithInstance>) => {
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithValue>) => {
state.exposedFields = uniqBy(
state.exposedFields.concat(omit(action.payload, 'value')),
(field) => `${field.nodeId}-${field.fieldName}`
@@ -128,25 +128,25 @@ export const workflowSlice = createSlice({
builder.addCase(workflowLoaded, (state, action) => {
const { nodes, edges: _edges, ...workflowExtra } = action.payload;
const originalExposedFieldValues: FieldIdentifierWithInstance[] = [];
const originalExposedFieldValues: FieldIdentifierWithValue[] = [];
workflowExtra.exposedFields.forEach(({ nodeId, fieldName }) => {
const node = nodes.find((n) => n.id === nodeId);
workflowExtra.exposedFields.forEach((field) => {
const node = nodes.find((n) => n.id === field.nodeId);
if (!isInvocationNode(node)) {
return;
}
const field = node.data.inputs[fieldName];
const input = node.data.inputs[field.fieldName];
if (!field) {
if (!input) {
return;
}
const originalExposedFieldValue = {
nodeId,
fieldName,
field,
nodeId: field.nodeId,
fieldName: field.fieldName,
value: input.value,
};
originalExposedFieldValues.push(originalExposedFieldValue);
});
@@ -243,9 +243,6 @@ const migrateWorkflowState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
return deepClone(initialWorkflowState);
}
return state;
};

View File

@@ -1,8 +1,3 @@
import {
zFloatRangeStartStepCountGenerator,
zIntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
import { z } from 'zod';
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
@@ -83,35 +78,14 @@ const zIntegerFieldType = zFieldTypeBase.extend({
name: z.literal('IntegerField'),
originalType: zStatelessFieldType.optional(),
});
const zIntegerCollectionFieldType = z.object({
name: z.literal('IntegerField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isIntegerCollectionFieldType = buildTypeGuard(zIntegerCollectionFieldType);
const zFloatFieldType = zFieldTypeBase.extend({
name: z.literal('FloatField'),
originalType: zStatelessFieldType.optional(),
});
const zFloatCollectionFieldType = z.object({
name: z.literal('FloatField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isFloatCollectionFieldType = buildTypeGuard(zFloatCollectionFieldType);
const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
originalType: zStatelessFieldType.optional(),
});
const zStringCollectionFieldType = z.object({
name: z.literal('StringField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isStringCollectionFieldType = buildTypeGuard(zStringCollectionFieldType);
const zBooleanFieldType = zFieldTypeBase.extend({
name: z.literal('BooleanField'),
originalType: zStatelessFieldType.optional(),
@@ -129,7 +103,9 @@ const zImageCollectionFieldType = z.object({
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isImageCollectionFieldType = buildTypeGuard(zImageCollectionFieldType);
export const isImageCollectionFieldType = (
fieldType: FieldType
): fieldType is z.infer<typeof zImageCollectionFieldType> => zImageCollectionFieldType.safeParse(fieldType).success;
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
originalType: zStatelessFieldType.optional(),
@@ -278,48 +254,10 @@ const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
export type IntegerFieldInputTemplate = z.infer<typeof zIntegerFieldInputTemplate>;
export const isIntegerFieldInputInstance = buildTypeGuard(zIntegerFieldInputInstance);
export const isIntegerFieldInputTemplate = buildTypeGuard(zIntegerFieldInputTemplate);
// #endregion
// #region IntegerField Collection
export const zIntegerFieldCollectionValue = z.array(zIntegerFieldValue).optional();
const zIntegerFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zIntegerFieldCollectionValue,
generator: zIntegerRangeStartStepCountGenerator.optional(),
lockLinearView: z.boolean().default(false),
});
const zIntegerFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zIntegerCollectionFieldType,
originalType: zFieldType.optional(),
default: zIntegerFieldCollectionValue,
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
multipleOf: z.number().int().optional(),
maximum: z.number().int().optional(),
exclusiveMaximum: z.number().int().optional(),
minimum: z.number().int().optional(),
exclusiveMinimum: z.number().int().optional(),
})
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zIntegerFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIntegerCollectionFieldType,
});
export type IntegerFieldCollectionValue = z.infer<typeof zIntegerFieldCollectionValue>;
export type IntegerFieldCollectionInputInstance = z.infer<typeof zIntegerFieldCollectionInputInstance>;
export type IntegerFieldCollectionInputTemplate = z.infer<typeof zIntegerFieldCollectionInputTemplate>;
export const isIntegerFieldCollectionInputInstance = buildTypeGuard(zIntegerFieldCollectionInputInstance);
export const isIntegerFieldCollectionInputTemplate = buildTypeGuard(zIntegerFieldCollectionInputTemplate);
export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance =>
zIntegerFieldInputInstance.safeParse(val).success;
export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate =>
zIntegerFieldInputTemplate.safeParse(val).success;
// #endregion
// #region FloatField
@@ -344,48 +282,10 @@ const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
export const isFloatFieldInputInstance = buildTypeGuard(zFloatFieldInputInstance);
export const isFloatFieldInputTemplate = buildTypeGuard(zFloatFieldInputTemplate);
// #endregion
// #region FloatField Collection
export const zFloatFieldCollectionValue = z.array(zFloatFieldValue).optional();
const zFloatFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zFloatFieldCollectionValue,
generator: zFloatRangeStartStepCountGenerator.optional(),
lockLinearView: z.boolean().default(false),
});
const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zFloatCollectionFieldType,
originalType: zFieldType.optional(),
default: zFloatFieldCollectionValue,
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
multipleOf: z.number().int().optional(),
maximum: z.number().optional(),
exclusiveMaximum: z.number().optional(),
minimum: z.number().optional(),
exclusiveMinimum: z.number().optional(),
})
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zFloatFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFloatCollectionFieldType,
});
export type FloatFieldCollectionInputInstance = z.infer<typeof zFloatFieldCollectionInputInstance>;
export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollectionInputTemplate>;
export const isFloatFieldCollectionInputInstance = buildTypeGuard(zFloatFieldCollectionInputInstance);
export const isFloatFieldCollectionInputTemplate = buildTypeGuard(zFloatFieldCollectionInputTemplate);
export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance =>
zFloatFieldInputInstance.safeParse(val).success;
export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate =>
zFloatFieldInputTemplate.safeParse(val).success;
// #endregion
// #region StringField
@@ -415,55 +315,13 @@ const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringFieldType,
});
// #region StringField Collection
export const zStringFieldCollectionValue = z.array(zStringFieldValue).optional();
const zStringFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zStringFieldCollectionValue,
});
const zStringFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zStringCollectionFieldType,
originalType: zFieldType.optional(),
default: zStringFieldCollectionValue,
maxLength: z.number().int().gte(0).optional(),
minLength: z.number().int().gte(0).optional(),
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
})
.refine(
(val) => {
if (val.maxLength !== undefined && val.minLength !== undefined) {
return val.maxLength >= val.minLength;
}
return true;
},
{ message: 'maxLength must be greater than or equal to minLength' }
)
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zStringFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringCollectionFieldType,
});
export type StringFieldCollectionValue = z.infer<typeof zStringFieldCollectionValue>;
export type StringFieldCollectionInputInstance = z.infer<typeof zStringFieldCollectionInputInstance>;
export type StringFieldCollectionInputTemplate = z.infer<typeof zStringFieldCollectionInputTemplate>;
export const isStringFieldCollectionInputInstance = buildTypeGuard(zStringFieldCollectionInputInstance);
export const isStringFieldCollectionInputTemplate = buildTypeGuard(zStringFieldCollectionInputTemplate);
// #endregion
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
export const isStringFieldInputInstance = buildTypeGuard(zStringFieldInputInstance);
export const isStringFieldInputTemplate = buildTypeGuard(zStringFieldInputTemplate);
export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance =>
zStringFieldInputInstance.safeParse(val).success;
export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate =>
zStringFieldInputTemplate.safeParse(val).success;
// #endregion
// #region BooleanField
@@ -483,8 +341,10 @@ const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
export type BooleanFieldInputTemplate = z.infer<typeof zBooleanFieldInputTemplate>;
export const isBooleanFieldInputInstance = buildTypeGuard(zBooleanFieldInputInstance);
export const isBooleanFieldInputTemplate = buildTypeGuard(zBooleanFieldInputTemplate);
export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance =>
zBooleanFieldInputInstance.safeParse(val).success;
export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate =>
zBooleanFieldInputTemplate.safeParse(val).success;
// #endregion
// #region EnumField
@@ -506,8 +366,10 @@ const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
export type EnumFieldInputTemplate = z.infer<typeof zEnumFieldInputTemplate>;
export const isEnumFieldInputInstance = buildTypeGuard(zEnumFieldInputInstance);
export const isEnumFieldInputTemplate = buildTypeGuard(zEnumFieldInputTemplate);
export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance =>
zEnumFieldInputInstance.safeParse(val).success;
export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate =>
zEnumFieldInputTemplate.safeParse(val).success;
// #endregion
// #region ImageField
@@ -526,8 +388,10 @@ const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
export const isImageFieldInputInstance = buildTypeGuard(zImageFieldInputInstance);
export const isImageFieldInputTemplate = buildTypeGuard(zImageFieldInputTemplate);
export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance =>
zImageFieldInputInstance.safeParse(val).success;
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
zImageFieldInputTemplate.safeParse(val).success;
// #endregion
// #region ImageField Collection
@@ -550,7 +414,7 @@ const zImageFieldCollectionInputTemplate = zFieldInputTemplateBase
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
{ message: 'maxLength must be greater than or equal to minLength' }
);
const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -559,8 +423,10 @@ const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
export type ImageFieldCollectionValue = z.infer<typeof zImageFieldCollectionValue>;
export type ImageFieldCollectionInputInstance = z.infer<typeof zImageFieldCollectionInputInstance>;
export type ImageFieldCollectionInputTemplate = z.infer<typeof zImageFieldCollectionInputTemplate>;
export const isImageFieldCollectionInputInstance = buildTypeGuard(zImageFieldCollectionInputInstance);
export const isImageFieldCollectionInputTemplate = buildTypeGuard(zImageFieldCollectionInputTemplate);
export const isImageFieldCollectionInputInstance = (val: unknown): val is ImageFieldCollectionInputInstance =>
zImageFieldCollectionInputInstance.safeParse(val).success;
export const isImageFieldCollectionInputTemplate = (val: unknown): val is ImageFieldCollectionInputTemplate =>
zImageFieldCollectionInputTemplate.safeParse(val).success;
// #endregion
// #region BoardField
@@ -580,8 +446,10 @@ const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
export type BoardFieldInputTemplate = z.infer<typeof zBoardFieldInputTemplate>;
export const isBoardFieldInputInstance = buildTypeGuard(zBoardFieldInputInstance);
export const isBoardFieldInputTemplate = buildTypeGuard(zBoardFieldInputTemplate);
export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance =>
zBoardFieldInputInstance.safeParse(val).success;
export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate =>
zBoardFieldInputTemplate.safeParse(val).success;
// #endregion
// #region ColorField
@@ -601,8 +469,10 @@ const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
export type ColorFieldInputTemplate = z.infer<typeof zColorFieldInputTemplate>;
export const isColorFieldInputInstance = buildTypeGuard(zColorFieldInputInstance);
export const isColorFieldInputTemplate = buildTypeGuard(zColorFieldInputTemplate);
export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance =>
zColorFieldInputInstance.safeParse(val).success;
export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate =>
zColorFieldInputTemplate.safeParse(val).success;
// #endregion
// #region MainModelField
@@ -622,8 +492,10 @@ const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
export const isMainModelFieldInputInstance = buildTypeGuard(zMainModelFieldInputInstance);
export const isMainModelFieldInputTemplate = buildTypeGuard(zMainModelFieldInputTemplate);
export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance =>
zMainModelFieldInputInstance.safeParse(val).success;
export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate =>
zMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region ModelIdentifierField
@@ -642,8 +514,10 @@ const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
export type ModelIdentifierFieldInputTemplate = z.infer<typeof zModelIdentifierFieldInputTemplate>;
export const isModelIdentifierFieldInputInstance = buildTypeGuard(zModelIdentifierFieldInputInstance);
export const isModelIdentifierFieldInputTemplate = buildTypeGuard(zModelIdentifierFieldInputTemplate);
export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance =>
zModelIdentifierFieldInputInstance.safeParse(val).success;
export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate =>
zModelIdentifierFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SDXLMainModelField
@@ -662,8 +536,10 @@ const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
});
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
export const isSDXLMainModelFieldInputInstance = buildTypeGuard(zSDXLMainModelFieldInputInstance);
export const isSDXLMainModelFieldInputTemplate = buildTypeGuard(zSDXLMainModelFieldInputTemplate);
export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance =>
zSDXLMainModelFieldInputInstance.safeParse(val).success;
export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate =>
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SD3MainModelField
@@ -682,8 +558,10 @@ const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = buildTypeGuard(zSD3MainModelFieldInputInstance);
export const isSD3MainModelFieldInputTemplate = buildTypeGuard(zSD3MainModelFieldInputTemplate);
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;
// #endregion
@@ -703,8 +581,10 @@ const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
});
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
export const isFluxMainModelFieldInputInstance = buildTypeGuard(zFluxMainModelFieldInputInstance);
export const isFluxMainModelFieldInputTemplate = buildTypeGuard(zFluxMainModelFieldInputTemplate);
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
zFluxMainModelFieldInputInstance.safeParse(val).success;
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
zFluxMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
@@ -726,8 +606,10 @@ const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
export const isSDXLRefinerModelFieldInputInstance = buildTypeGuard(zSDXLRefinerModelFieldInputInstance);
export const isSDXLRefinerModelFieldInputTemplate = buildTypeGuard(zSDXLRefinerModelFieldInputTemplate);
export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance =>
zSDXLRefinerModelFieldInputInstance.safeParse(val).success;
export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate =>
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region VAEModelField
@@ -747,8 +629,10 @@ const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
export const isVAEModelFieldInputInstance = buildTypeGuard(zVAEModelFieldInputInstance);
export const isVAEModelFieldInputTemplate = buildTypeGuard(zVAEModelFieldInputTemplate);
export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance =>
zVAEModelFieldInputInstance.safeParse(val).success;
export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate =>
zVAEModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region LoRAModelField
@@ -768,8 +652,10 @@ const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
export const isLoRAModelFieldInputInstance = buildTypeGuard(zLoRAModelFieldInputInstance);
export const isLoRAModelFieldInputTemplate = buildTypeGuard(zLoRAModelFieldInputTemplate);
export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance =>
zLoRAModelFieldInputInstance.safeParse(val).success;
export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate =>
zLoRAModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region ControlNetModelField
@@ -789,8 +675,10 @@ const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
export const isControlNetModelFieldInputInstance = buildTypeGuard(zControlNetModelFieldInputInstance);
export const isControlNetModelFieldInputTemplate = buildTypeGuard(zControlNetModelFieldInputTemplate);
export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance =>
zControlNetModelFieldInputInstance.safeParse(val).success;
export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate =>
zControlNetModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region IPAdapterModelField
@@ -810,8 +698,10 @@ const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
export const isIPAdapterModelFieldInputInstance = buildTypeGuard(zIPAdapterModelFieldInputInstance);
export const isIPAdapterModelFieldInputTemplate = buildTypeGuard(zIPAdapterModelFieldInputTemplate);
export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance =>
zIPAdapterModelFieldInputInstance.safeParse(val).success;
export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate =>
zIPAdapterModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region T2IAdapterField
@@ -831,8 +721,10 @@ const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
export const isT2IAdapterModelFieldInputInstance = buildTypeGuard(zT2IAdapterModelFieldInputInstance);
export const isT2IAdapterModelFieldInputTemplate = buildTypeGuard(zT2IAdapterModelFieldInputTemplate);
export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance =>
zT2IAdapterModelFieldInputInstance.safeParse(val).success;
export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate =>
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SpandrelModelToModelField
@@ -852,12 +744,14 @@ const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.e
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
export const isSpandrelImageToImageModelFieldInputInstance = buildTypeGuard(
zSpandrelImageToImageModelFieldInputInstance
);
export const isSpandrelImageToImageModelFieldInputTemplate = buildTypeGuard(
zSpandrelImageToImageModelFieldInputTemplate
);
export const isSpandrelImageToImageModelFieldInputInstance = (
val: unknown
): val is SpandrelImageToImageModelFieldInputInstance =>
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
export const isSpandrelImageToImageModelFieldInputTemplate = (
val: unknown
): val is SpandrelImageToImageModelFieldInputTemplate =>
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region T5EncoderModelField
@@ -876,8 +770,10 @@ export type T5EncoderModelFieldValue = z.infer<typeof zT5EncoderModelFieldValue>
export type T5EncoderModelFieldInputInstance = z.infer<typeof zT5EncoderModelFieldInputInstance>;
export type T5EncoderModelFieldInputTemplate = z.infer<typeof zT5EncoderModelFieldInputTemplate>;
export const isT5EncoderModelFieldInputInstance = buildTypeGuard(zT5EncoderModelFieldInputInstance);
export const isT5EncoderModelFieldInputTemplate = buildTypeGuard(zT5EncoderModelFieldInputTemplate);
export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5EncoderModelFieldInputInstance =>
zT5EncoderModelFieldInputInstance.safeParse(val).success;
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
// #endregion
@@ -897,8 +793,10 @@ export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
export const isFluxVAEModelFieldInputInstance = buildTypeGuard(zFluxVAEModelFieldInputInstance);
export const isFluxVAEModelFieldInputTemplate = buildTypeGuard(zFluxVAEModelFieldInputTemplate);
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
zFluxVAEModelFieldInputInstance.safeParse(val).success;
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
zFluxVAEModelFieldInputTemplate.safeParse(val).success;
// #endregion
@@ -918,8 +816,10 @@ export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
export const isCLIPEmbedModelFieldInputInstance = buildTypeGuard(zCLIPEmbedModelFieldInputInstance);
export const isCLIPEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPEmbedModelFieldInputTemplate);
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
// #endregion
@@ -939,8 +839,10 @@ export type CLIPLEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValu
export type CLIPLEmbedModelFieldInputInstance = z.infer<typeof zCLIPLEmbedModelFieldInputInstance>;
export type CLIPLEmbedModelFieldInputTemplate = z.infer<typeof zCLIPLEmbedModelFieldInputTemplate>;
export const isCLIPLEmbedModelFieldInputInstance = buildTypeGuard(zCLIPLEmbedModelFieldInputInstance);
export const isCLIPLEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPLEmbedModelFieldInputTemplate);
export const isCLIPLEmbedModelFieldInputInstance = (val: unknown): val is CLIPLEmbedModelFieldInputInstance =>
zCLIPLEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPLEmbedModelFieldInputTemplate = (val: unknown): val is CLIPLEmbedModelFieldInputTemplate =>
zCLIPLEmbedModelFieldInputTemplate.safeParse(val).success;
// #endregion
@@ -960,8 +862,10 @@ export type CLIPGEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValu
export type CLIPGEmbedModelFieldInputInstance = z.infer<typeof zCLIPGEmbedModelFieldInputInstance>;
export type CLIPGEmbedModelFieldInputTemplate = z.infer<typeof zCLIPGEmbedModelFieldInputTemplate>;
export const isCLIPGEmbedModelFieldInputInstance = buildTypeGuard(zCLIPGEmbedModelFieldInputInstance);
export const isCLIPGEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPGEmbedModelFieldInputTemplate);
export const isCLIPGEmbedModelFieldInputInstance = (val: unknown): val is CLIPGEmbedModelFieldInputInstance =>
zCLIPGEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGEmbedModelFieldInputTemplate =>
zCLIPGEmbedModelFieldInputTemplate.safeParse(val).success;
// #endregion
@@ -981,8 +885,10 @@ export type ControlLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldVal
export type ControlLoRAModelFieldInputInstance = z.infer<typeof zControlLoRAModelFieldInputInstance>;
export type ControlLoRAModelFieldInputTemplate = z.infer<typeof zControlLoRAModelFieldInputTemplate>;
export const isControlLoRAModelFieldInputInstance = buildTypeGuard(zControlLoRAModelFieldInputInstance);
export const isControlLoRAModelFieldInputTemplate = buildTypeGuard(zControlLoRAModelFieldInputTemplate);
export const isControlLoRAModelFieldInputInstance = (val: unknown): val is ControlLoRAModelFieldInputInstance =>
zControlLoRAModelFieldInputInstance.safeParse(val).success;
export const isControlLoRAModelFieldInputTemplate = (val: unknown): val is ControlLoRAModelFieldInputTemplate =>
zControlLoRAModelFieldInputTemplate.safeParse(val).success;
// #endregion
@@ -1003,8 +909,10 @@ const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
export const isSchedulerFieldInputInstance = buildTypeGuard(zSchedulerFieldInputInstance);
export const isSchedulerFieldInputTemplate = buildTypeGuard(zSchedulerFieldInputTemplate);
export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
zSchedulerFieldInputInstance.safeParse(val).success;
export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate =>
zSchedulerFieldInputTemplate.safeParse(val).success;
// #endregion
// #region StatelessField
@@ -1055,11 +963,8 @@ export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTem
// #region StatefulFieldValue & FieldValue
export const zStatefulFieldValue = z.union([
zIntegerFieldValue,
zIntegerFieldCollectionValue,
zFloatFieldValue,
zFloatFieldCollectionValue,
zStringFieldValue,
zStringFieldCollectionValue,
zBooleanFieldValue,
zEnumFieldValue,
zImageFieldValue,
@@ -1095,11 +1000,8 @@ export type FieldValue = z.infer<typeof zFieldValue>;
// #region StatefulFieldInputInstance & FieldInputInstance
const zStatefulFieldInputInstance = z.union([
zIntegerFieldInputInstance,
zIntegerFieldCollectionInputInstance,
zFloatFieldInputInstance,
zFloatFieldCollectionInputInstance,
zStringFieldInputInstance,
zStringFieldCollectionInputInstance,
zBooleanFieldInputInstance,
zEnumFieldInputInstance,
zImageFieldInputInstance,
@@ -1126,17 +1028,15 @@ const zStatefulFieldInputInstance = z.union([
export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]);
export type FieldInputInstance = z.infer<typeof zFieldInputInstance>;
export const isFieldInputInstance = buildTypeGuard(zFieldInputInstance);
export const isFieldInputInstance = (val: unknown): val is FieldInputInstance =>
zFieldInputInstance.safeParse(val).success;
// #endregion
// #region StatefulFieldInputTemplate & FieldInputTemplate
const zStatefulFieldInputTemplate = z.union([
zIntegerFieldInputTemplate,
zIntegerFieldCollectionInputTemplate,
zFloatFieldInputTemplate,
zFloatFieldCollectionInputTemplate,
zStringFieldInputTemplate,
zStringFieldCollectionInputTemplate,
zBooleanFieldInputTemplate,
zEnumFieldInputTemplate,
zImageFieldInputTemplate,
@@ -1167,17 +1067,15 @@ const zStatefulFieldInputTemplate = z.union([
export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]);
export type FieldInputTemplate = z.infer<typeof zFieldInputTemplate>;
export const isFieldInputTemplate = buildTypeGuard(zFieldInputTemplate);
export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate =>
zFieldInputTemplate.safeParse(val).success;
// #endregion
// #region StatefulFieldOutputTemplate & FieldOutputTemplate
const zStatefulFieldOutputTemplate = z.union([
zIntegerFieldOutputTemplate,
zIntegerFieldCollectionOutputTemplate,
zFloatFieldOutputTemplate,
zFloatFieldCollectionOutputTemplate,
zStringFieldOutputTemplate,
zStringFieldCollectionOutputTemplate,
zBooleanFieldOutputTemplate,
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,

View File

@@ -1,133 +0,0 @@
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
ImageFieldCollectionInputTemplate,
ImageFieldCollectionValue,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
StringFieldCollectionInputTemplate,
StringFieldCollectionValue,
} from 'features/nodes/types/field';
import {
floatRangeStartStepCountGenerator,
integerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { t } from 'i18next';
export const validateImageFieldCollectionValue = (
value: NonNullable<ImageFieldCollectionValue>,
template: ImageFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems } = template;
const count = value.length;
// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}
if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}
if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}
return reasons;
};
export const validateStringFieldCollectionValue = (
value: NonNullable<StringFieldCollectionValue>,
template: StringFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minLength, maxLength } = template;
const count = value.length;
// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}
if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}
if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}
for (const str of value) {
if (maxLength !== undefined && str.length > maxLength) {
reasons.push(t('parameters.invoke.collectionStringTooLong', { value, maxLength }));
}
if (minLength !== undefined && str.length < minLength) {
reasons.push(t('parameters.invoke.collectionStringTooShort', { value, minLength }));
}
}
return reasons;
};
export const resolveNumberFieldCollectionValue = (
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance
): number[] | undefined => {
if (field.generator?.type === 'float-range-generator-start-step-count') {
return floatRangeStartStepCountGenerator(field.generator);
} else if (field.generator?.type === 'integer-range-generator-start-step-count') {
return integerRangeStartStepCountGenerator(field.generator);
} else {
return field.value;
}
};
export const validateNumberFieldCollectionValue = (
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance,
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum, multipleOf } = template;
const value = resolveNumberFieldCollectionValue(field);
if (value === undefined) {
reasons.push(t('parameters.invoke.collectionEmpty'));
return reasons;
}
const count = value.length;
// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}
if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}
if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}
for (const num of value) {
if (maximum !== undefined && num > maximum) {
reasons.push(t('parameters.invoke.collectionNumberGTMax', { value, maximum }));
}
if (minimum !== undefined && num < minimum) {
reasons.push(t('parameters.invoke.collectionNumberLTMin', { value, minimum }));
}
if (exclusiveMaximum !== undefined && num >= exclusiveMaximum) {
reasons.push(t('parameters.invoke.collectionNumberGTExclusiveMax', { value, exclusiveMaximum }));
}
if (exclusiveMinimum !== undefined && num <= exclusiveMinimum) {
reasons.push(t('parameters.invoke.collectionNumberLTExclusiveMin', { value, exclusiveMinimum }));
}
if (multipleOf !== undefined && num % multipleOf !== 0) {
reasons.push(t('parameters.invoke.collectionNumberNotMultipleOf', { value, multipleOf }));
}
}
return reasons;
};

View File

@@ -1,29 +0,0 @@
import { z } from 'zod';
export const zFloatRangeStartStepCountGenerator = z.object({
type: z.literal('float-range-generator-start-step-count').default('float-range-generator-start-step-count'),
start: z.number().default(0),
step: z.number().default(1),
count: z.number().int().default(10),
});
export type FloatRangeStartStepCountGenerator = z.infer<typeof zFloatRangeStartStepCountGenerator>;
export const floatRangeStartStepCountGenerator = (generator: FloatRangeStartStepCountGenerator): number[] => {
const { start, step, count } = generator;
return Array.from({ length: count }, (_, i) => start + i * step);
};
export const getDefaultFloatRangeStartStepCountGenerator = (): FloatRangeStartStepCountGenerator =>
zFloatRangeStartStepCountGenerator.parse({});
export const zIntegerRangeStartStepCountGenerator = z.object({
type: z.literal('integer-range-generator-start-step-count').default('integer-range-generator-start-step-count'),
start: z.number().int().default(0),
step: z.number().int().default(1),
count: z.number().int().default(10),
});
export type IntegerRangeStartStepCountGenerator = z.infer<typeof zIntegerRangeStartStepCountGenerator>;
export const integerRangeStartStepCountGenerator = (generator: IntegerRangeStartStepCountGenerator): number[] => {
const { start, step, count } = generator;
return Array.from({ length: count }, (_, i) => start + i * step);
};
export const getDefaultIntegerRangeStartStepCountGenerator = (): IntegerRangeStartStepCountGenerator =>
zIntegerRangeStartStepCountGenerator.parse({});

View File

@@ -91,15 +91,3 @@ const zInvocationNodeEdgeExtra = z.object({
type InvocationNodeEdgeExtra = z.infer<typeof zInvocationNodeEdgeExtra>;
export type InvocationNodeEdge = Edge<InvocationNodeEdgeExtra>;
// #endregion
export const isBatchNode = (node: InvocationNode) => {
switch (node.data.type) {
case 'image_batch':
case 'string_batch':
case 'integer_batch':
case 'float_batch':
return true;
default:
return false;
}
};

View File

@@ -1,9 +1,7 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { negate, omit, reduce } from 'lodash-es';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
@@ -16,7 +14,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const { nodes, edges } = nodesState;
// Exclude all batch nodes - we will handle these in the batch setup in a diff function
const filteredNodes = nodes.filter(isInvocationNode).filter(negate(isBatchNode));
const filteredNodes = nodes.filter(isInvocationNode).filter((node) => node.data.type !== 'image_batch');
// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
@@ -27,11 +25,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const transformedInputs = reduce(
inputs,
(inputsAccumulator, input, name) => {
if (isFloatFieldCollectionInputInstance(input) || isIntegerFieldCollectionInputInstance(input)) {
inputsAccumulator[name] = resolveNumberFieldCollectionValue(input);
} else {
inputsAccumulator[name] = input.value;
}
inputsAccumulator[name] = input.value;
return inputsAccumulator;
},

Some files were not shown because too many files have changed in this diff Show More