mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-23 06:58:10 -05:00
Compare commits
1 Commits
psyche/ref
...
ryan/parti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cd268ff5b6 |
@@ -39,7 +39,7 @@ It has two sections - one for internal use and one for user settings:
|
||||
|
||||
```yaml
|
||||
# Internal metadata - do not edit:
|
||||
schema_version: 4.0.2
|
||||
schema_version: 4
|
||||
|
||||
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:
|
||||
host: 0.0.0.0 # serve the app on your local network
|
||||
@@ -83,10 +83,6 @@ A subset of settings may be specified using CLI args:
|
||||
- `--root`: specify the root directory
|
||||
- `--config`: override the default `invokeai.yaml` file location
|
||||
|
||||
### Low-VRAM Mode
|
||||
|
||||
See the [Low-VRAM mode docs][low-vram] for details on enabling this feature.
|
||||
|
||||
### All Settings
|
||||
|
||||
Following the table are additional explanations for certain settings.
|
||||
@@ -189,4 +185,3 @@ The `log_format` option provides several alternative formats:
|
||||
|
||||
[basic guide to yaml files]: https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/
|
||||
[Model Marketplace API Keys]: #model-marketplace-api-keys
|
||||
[low-vram]: ./features/low-vram.md
|
||||
|
||||
@@ -22,7 +22,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
|
||||
4. Follow the [manual install][manual install link] guide, with some modifications to the install command:
|
||||
|
||||
- Use `.` instead of `invokeai` to install from the current directory. You don't need to specify the version.
|
||||
- Use `.` instead of `invokeai` to install from the current directory.
|
||||
|
||||
- Add `-e` after the `install` operation to make this an [editable install][editable install link]. That means your changes to the python code will be reflected when you restart the Invoke server.
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 72 KiB |
@@ -1,129 +0,0 @@
|
||||
---
|
||||
title: Low-VRAM mode
|
||||
---
|
||||
|
||||
As of v5.6.0, Invoke has a low-VRAM mode. It works on systems with dedicated GPUs (Nvidia GPUs on Windows/Linux and AMD GPUs on Linux).
|
||||
|
||||
This allows you to generate even if your GPU doesn't have enough VRAM to hold full models. Most users should be able to run even the beefiest models - like the ~24GB unquantised FLUX dev model.
|
||||
|
||||
## Enabling Low-VRAM mode
|
||||
|
||||
To enable Low-VRAM mode, add this line to your `invokeai.yaml` configuration file, then restart Invoke:
|
||||
|
||||
```yaml
|
||||
enable_partial_loading: true
|
||||
```
|
||||
|
||||
**Windows users should also [disable the Nvidia sysmem fallback](#disabling-nvidia-sysmem-fallback-windows-only)**.
|
||||
|
||||
It is possible to fine-tune the settings for best performance or if you still get out-of-memory errors (OOMs).
|
||||
|
||||
!!! tip "How to find `invokeai.yaml`"
|
||||
|
||||
The `invokeai.yaml` configuration file lives in your install directory. To access it, run the **Invoke Community Edition** launcher and click the install location. This will open your install directory in a file explorer window.
|
||||
|
||||
You'll see `invokeai.yaml` there and can edit it with any text editor. After making changes, restart Invoke.
|
||||
|
||||
If you don't see `invokeai.yaml`, launch Invoke once. It will create the file on its first startup.
|
||||
|
||||
## Details and fine-tuning
|
||||
|
||||
Low-VRAM mode involves 3 features, each of which can be configured or fine-tuned:
|
||||
|
||||
- Partial model loading
|
||||
- Dynamic RAM and VRAM cache sizes
|
||||
- Working memory
|
||||
|
||||
Read on to learn about these features and understand how to fine-tune them for your system and use-cases.
|
||||
|
||||
### Partial model loading
|
||||
|
||||
Invoke's partial model loading works by streaming model "layers" between RAM and VRAM as they are needed.
|
||||
|
||||
When an operation needs layers that are not in VRAM, but there isn't enough room to load them, inactive layers are offloaded to RAM to make room.
|
||||
|
||||
#### Enabling partial model loading
|
||||
|
||||
As described above, you can enable partial model loading by adding this line to `invokeai.yaml`:
|
||||
|
||||
```yaml
|
||||
enable_partial_loading: true
|
||||
```
|
||||
|
||||
### Dynamic RAM and VRAM cache sizes
|
||||
|
||||
Loading models from disk is slow and can be a major bottleneck for performance. Invoke uses two model caches - RAM and VRAM - to reduce loading from disk to a minimum.
|
||||
|
||||
By default, Invoke manages these caches' sizes dynamically for best performance.
|
||||
|
||||
#### Fine-tuning cache sizes
|
||||
|
||||
Prior to v5.6.0, the cache sizes were static, and for best performance, many users needed to manually fine-tune the `ram` and `vram` settings in `invokeai.yaml`.
|
||||
|
||||
As of v5.6.0, the caches are dynamically sized. The `ram` and `vram` settings are no longer used, and new settings are added to configure the cache.
|
||||
|
||||
**Most users will not need to fine-tune the cache sizes.**
|
||||
|
||||
But, if your GPU has enough VRAM to hold models fully, you might get a perf boost by manually setting the cache sizes in `invokeai.yaml`:
|
||||
|
||||
```yaml
|
||||
# Set the RAM cache size to as large as possible, leaving a few GB free for the rest of your system and Invoke.
|
||||
# For example, if your system has 32GB RAM, 28GB is a good value.
|
||||
max_cache_ram_gb: 28
|
||||
# Set the VRAM cache size to be as large as possible while leaving enough room for the working memory of the tasks you will be doing.
|
||||
# For example, on a 24GB GPU that will be running unquantized FLUX without any auxiliary models,
|
||||
# 18GB is a good value.
|
||||
max_cache_vram_gb: 18
|
||||
```
|
||||
|
||||
!!! tip "Max safe value for `max_cache_vram_gb`"
|
||||
|
||||
To determine the max safe value for `max_cache_vram_gb`, subtract `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
|
||||
|
||||
For example, if you have a 12GB GPU, the max safe value for `max_cache_vram_gb` is `12GB - 3GB = 9GB`.
|
||||
|
||||
If you had increased `device_working_mem_gb` to 4GB, then the max safe value for `max_cache_vram_gb` is `12GB - 4GB = 8GB`.
|
||||
|
||||
### Working memory
|
||||
|
||||
Invoke cannot use _all_ of your VRAM for model caching and loading. It requires some VRAM to use as working memory for various operations.
|
||||
|
||||
Invoke reserves 3GB VRAM as working memory by default, which is enough for most use-cases. However, it is possible to fine-tune this setting if you still get OOMs.
|
||||
|
||||
#### Fine-tuning working memory
|
||||
|
||||
You can increase the working memory size in `invokeai.yaml` to prevent OOMs:
|
||||
|
||||
```yaml
|
||||
# The default is 3GB - bump it up to 4GB to prevent OOMs.
|
||||
device_working_mem_gb: 4
|
||||
```
|
||||
|
||||
!!! tip "Operations may request more working memory"
|
||||
|
||||
For some operations, we can determine VRAM requirements in advance and allocate additional working memory to prevent OOMs.
|
||||
|
||||
VAE decoding is one such operation. This operation converts the generation process's output into an image. For large image outputs, this might use more than the default working memory size of 3GB.
|
||||
|
||||
During this decoding step, Invoke calculates how much VRAM will be required to decode and requests that much VRAM from the model manager. If the amount exceeds the working memory size, the model manager will offload cached model layers from VRAM until there's enough VRAM to decode.
|
||||
|
||||
Once decoding completes, the model manager "reclaims" the extra VRAM allocated as working memory for future model loading operations.
|
||||
|
||||
### Disabling Nvidia sysmem fallback (Windows only)
|
||||
|
||||
On Windows, Nvidia GPUs are able to use system RAM when their VRAM fills up via **sysmem fallback**. While it sounds like a good idea on the surface, in practice it causes massive slowdowns during generation.
|
||||
|
||||
It is strongly suggested to disable this feature:
|
||||
|
||||
- Open the **NVIDIA Control Panel** app.
|
||||
- Expand **3D Settings** on the left panel.
|
||||
- Click **Manage 3D Settings** in the left panel.
|
||||
- Find **CUDA - Sysmem Fallback Policy** in the right panel and set it to **Prefer No Sysmem Fallback**.
|
||||
|
||||

|
||||
|
||||
!!! tip "Invoke does the same thing, but better"
|
||||
|
||||
If the sysmem fallback feature sounds familiar, that's because Invoke's partial model loading strategy is conceptually very similar - use VRAM when there's room, else fall back to RAM.
|
||||
|
||||
Unfortunately, the Nvidia implementation is not optimized for applications like Invoke and does more harm than good.
|
||||
@@ -75,14 +75,14 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
|
||||
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.1`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm62`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
=== "Invoke v4"
|
||||
|
||||
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm5.2`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm52`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
8. Install the `invokeai` package. Substitute the package specifier and version.
|
||||
|
||||
@@ -54,7 +54,7 @@ If you have an existing Invoke installation, you can select it and let the launc
|
||||
- Open the **Invoke-Installer-mac-arm64.dmg** file.
|
||||
- Drag the launcher to **Applications**.
|
||||
- Open a terminal.
|
||||
- Run `xattr -d 'com.apple.quarantine' /Applications/Invoke\ Community\ Edition.app`.
|
||||
- Run `xattr -cr /Applications/Invoke-Installer.app`.
|
||||
|
||||
You should now be able to run the launcher.
|
||||
|
||||
|
||||
@@ -535,7 +535,7 @@ View:
|
||||
**Node Link:** https://github.com/simonfuhrmann/invokeai-stereo
|
||||
|
||||
**Example Workflow and Output**
|
||||
</br><img src="https://raw.githubusercontent.com/simonfuhrmann/invokeai-stereo/refs/heads/main/docs/example_promo_03.jpg" width="600" />
|
||||
</br><img src="https://github.com/simonfuhrmann/invokeai-stereo/blob/main/docs/example_promo_03.jpg" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Simple Skin Detection
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import contextlib
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
@@ -20,6 +21,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
@@ -846,6 +848,74 @@ async def get_starter_models() -> StarterModelResponse:
|
||||
return StarterModelResponse(starter_models=starter_models, starter_bundles=starter_bundles)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/model_cache",
|
||||
operation_id="get_cache_size",
|
||||
response_model=float,
|
||||
summary="Get maximum size of model manager RAM or VRAM cache.",
|
||||
)
|
||||
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
|
||||
"""Return the current RAM or VRAM cache size setting (in GB)."""
|
||||
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
|
||||
value = 0.0
|
||||
if cache_type == CacheType.RAM:
|
||||
value = cache.max_cache_size
|
||||
elif cache_type == CacheType.VRAM:
|
||||
value = cache.max_vram_cache_size
|
||||
return value
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/model_cache",
|
||||
operation_id="set_cache_size",
|
||||
response_model=float,
|
||||
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
|
||||
)
|
||||
async def set_cache_size(
|
||||
value: float = Query(description="The new value for the maximum cache size"),
|
||||
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
|
||||
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
|
||||
) -> float:
|
||||
"""Set the current RAM or VRAM cache size setting (in GB). ."""
|
||||
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
|
||||
app_config = get_config()
|
||||
# Record initial state.
|
||||
vram_old = app_config.vram
|
||||
ram_old = app_config.ram
|
||||
|
||||
# Prepare target state.
|
||||
vram_new = vram_old
|
||||
ram_new = ram_old
|
||||
if cache_type == CacheType.RAM:
|
||||
ram_new = value
|
||||
elif cache_type == CacheType.VRAM:
|
||||
vram_new = value
|
||||
else:
|
||||
raise ValueError(f"Unexpected {cache_type=}.")
|
||||
|
||||
config_path = app_config.config_file_path
|
||||
new_config_path = config_path.with_suffix(".yaml.new")
|
||||
|
||||
try:
|
||||
# Try to apply the target state.
|
||||
cache.max_vram_cache_size = vram_new
|
||||
cache.max_cache_size = ram_new
|
||||
app_config.ram = ram_new
|
||||
app_config.vram = vram_new
|
||||
if persist:
|
||||
app_config.write_file(new_config_path)
|
||||
shutil.move(new_config_path, config_path)
|
||||
except Exception as e:
|
||||
# If there was a failure, restore the initial state.
|
||||
cache.max_cache_size = ram_old
|
||||
cache.max_vram_cache_size = vram_old
|
||||
app_config.ram = ram_old
|
||||
app_config.vram = vram_old
|
||||
|
||||
raise RuntimeError("Failed to update cache size") from e
|
||||
return value
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/stats",
|
||||
operation_id="get_stats",
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import FloatOutput, ImageOutput, IntegerOutput, StringOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
BATCH_GROUP_IDS = Literal[
|
||||
"None",
|
||||
"Group 1",
|
||||
"Group 2",
|
||||
"Group 3",
|
||||
"Group 4",
|
||||
"Group 5",
|
||||
]
|
||||
|
||||
|
||||
class NotExecutableNodeError(Exception):
|
||||
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
|
||||
super().__init__(message)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseBatchInvocation(BaseInvocation):
|
||||
batch_group_id: BATCH_GROUP_IDS = InputField(
|
||||
default="None",
|
||||
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
|
||||
input=Input.Direct,
|
||||
title="Batch Group",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(
|
||||
default=[], min_length=1, description="The images to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"string_batch",
|
||||
title="String Batch",
|
||||
tags=["primitives", "string", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class StringBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
|
||||
|
||||
strings: list[str] = InputField(
|
||||
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"integer_batch",
|
||||
title="Integer Batch",
|
||||
tags=["primitives", "integer", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class IntegerBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
|
||||
|
||||
integers: list[int] = InputField(
|
||||
default=[], min_length=1, description="The integers to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"float_batch",
|
||||
title="Float Batch",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class FloatBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
|
||||
|
||||
floats: list[float] = InputField(
|
||||
default=[], min_length=1, description="The floats to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
raise NotExecutableNodeError()
|
||||
@@ -63,6 +63,9 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
@@ -73,13 +76,12 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
context.models.load(self.clip.tokenizer) as tokenizer,
|
||||
tokenizer_info as tokenizer,
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=text_encoder,
|
||||
patches=_lora_loader(),
|
||||
@@ -103,7 +105,6 @@ class CompelInvocation(BaseInvocation):
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
@@ -138,7 +139,9 @@ class SDXLPromptInvocationBase:
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
cpu_text_encoder = text_encoder_info.model
|
||||
@@ -176,7 +179,7 @@ class SDXLPromptInvocationBase:
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
context.models.load(clip_field.tokenizer) as tokenizer,
|
||||
tokenizer_info as tokenizer,
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=text_encoder,
|
||||
patches=_lora_loader(),
|
||||
@@ -204,7 +207,6 @@ class SDXLPromptInvocationBase:
|
||||
truncate_long_prompts=False, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=get_pooled,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt)
|
||||
@@ -222,6 +224,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
del tokenizer
|
||||
del text_encoder
|
||||
del tokenizer_info
|
||||
del text_encoder_info
|
||||
|
||||
c = c.detach().to("cpu")
|
||||
|
||||
@@ -10,9 +10,7 @@ import torchvision.transforms as T
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.adapter import T2IAdapter
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
||||
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
|
||||
from diffusers.schedulers.scheduling_tcd import TCDScheduler
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
|
||||
from PIL import Image
|
||||
@@ -40,7 +38,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -86,14 +83,12 @@ def get_scheduler(
|
||||
scheduler_info: ModelIdentifierField,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
unet_config: AnyModelConfig,
|
||||
) -> Scheduler:
|
||||
"""Load a scheduler and apply some scheduler-specific overrides."""
|
||||
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
|
||||
# possible.
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(scheduler_info)
|
||||
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
@@ -105,17 +100,10 @@ def get_scheduler(
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
|
||||
if hasattr(unet_config, "prediction_type"):
|
||||
scheduler_config["prediction_type"] = unet_config.prediction_type
|
||||
|
||||
# make dpmpp_sde reproducable(seed can be passed only in initializer)
|
||||
if scheduler_class is DPMSolverSDEScheduler:
|
||||
scheduler_config["noise_sampler_seed"] = seed
|
||||
|
||||
if scheduler_class is DPMSolverMultistepScheduler or scheduler_class is DPMSolverSinglestepScheduler:
|
||||
if scheduler_config["_class_name"] == "DEISMultistepScheduler" and scheduler_config["algorithm_type"] == "deis":
|
||||
scheduler_config["algorithm_type"] = "dpmsolver++"
|
||||
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
@@ -423,7 +411,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
control_input: ControlField | list[ControlField] | None,
|
||||
latents_shape: List[int],
|
||||
device: torch.device,
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> list[ControlNetData] | None:
|
||||
@@ -465,7 +452,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
height=control_height_resize,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
control_mode=control_info.control_mode,
|
||||
resize_mode=control_info.resize_mode,
|
||||
@@ -560,6 +547,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
for single_ip_adapter in ip_adapters:
|
||||
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
|
||||
assert isinstance(ip_adapter_model, IPAdapter)
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
@@ -568,7 +556,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
single_ipa_images = [
|
||||
context.images.get_pil(image.image_name, mode="RGB") for image in single_ipa_image_fields
|
||||
]
|
||||
with context.models.load(single_ip_adapter.image_encoder_model) as image_encoder_model:
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
@@ -618,7 +606,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||
latents_shape: list[int],
|
||||
device: torch.device,
|
||||
do_classifier_free_guidance: bool,
|
||||
) -> Optional[list[T2IAdapterData]]:
|
||||
if t2i_adapter is None:
|
||||
@@ -634,6 +621,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
@@ -649,7 +637,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
|
||||
|
||||
t2i_adapter_model: T2IAdapter
|
||||
with context.models.load(t2i_adapter_field.t2i_adapter_model) as t2i_adapter_model:
|
||||
with t2i_adapter_loaded_model as t2i_adapter_model:
|
||||
total_downscale_factor = t2i_adapter_model.total_downscale_factor
|
||||
|
||||
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
||||
@@ -669,7 +657,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
|
||||
device=device,
|
||||
device=t2i_adapter_model.device,
|
||||
dtype=t2i_adapter_model.dtype,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
)
|
||||
@@ -834,9 +822,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
@@ -856,7 +841,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
unet_config=unet_config,
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
@@ -868,6 +852,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
### preview
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
@@ -939,8 +926,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||
# ext: controlnet
|
||||
ext_manager.patch_extensions(denoise_ctx),
|
||||
@@ -961,7 +950,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
device = TorchDevice.choose_torch_device()
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
@@ -976,7 +964,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context,
|
||||
self.t2i_adapter,
|
||||
latents.shape,
|
||||
device=device,
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
@@ -1008,9 +995,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
@@ -1023,20 +1012,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=unet.dtype)
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
if mask is not None:
|
||||
mask = mask.to(device=device, dtype=unet.dtype)
|
||||
mask = mask.to(device=unet.device, dtype=unet.dtype)
|
||||
if masked_latents is not None:
|
||||
masked_latents = masked_latents.to(device=device, dtype=unet.dtype)
|
||||
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
unet_config=unet_config,
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
@@ -1046,7 +1034,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
device=device,
|
||||
device=unet.device,
|
||||
dtype=unet.dtype,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
@@ -1059,7 +1047,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=latents.shape,
|
||||
device=device,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
@@ -1077,7 +1064,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
scheduler,
|
||||
device=device,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
|
||||
@@ -199,8 +199,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
else None
|
||||
)
|
||||
|
||||
transformer_config = context.models.get_config(self.transformer.transformer)
|
||||
is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
timesteps = get_schedule(
|
||||
@@ -276,7 +276,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# TODO(ryand): We should really do this in a separate invocation to benefit from caching.
|
||||
ip_adapter_fields = self._normalize_ip_adapter_fields()
|
||||
pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
|
||||
ip_adapter_fields, context, device=x.device
|
||||
ip_adapter_fields, context
|
||||
)
|
||||
|
||||
cfg_scale = self.prep_cfg_scale(
|
||||
@@ -299,11 +299,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
)
|
||||
|
||||
# Load the transformer model.
|
||||
(cached_weights, transformer) = exit_stack.enter_context(
|
||||
context.models.load(self.transformer.transformer).model_on_device()
|
||||
)
|
||||
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
|
||||
assert isinstance(transformer, Flux)
|
||||
config = transformer_config
|
||||
config = transformer_info.config
|
||||
assert config is not None
|
||||
|
||||
# Determine if the model is quantized.
|
||||
@@ -514,18 +512,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
|
||||
# minimize peak memory.
|
||||
|
||||
# First, load the ControlNet models so that we can determine the ControlNet types.
|
||||
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
|
||||
|
||||
# Calculate the controlnet conditioning tensors.
|
||||
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
|
||||
# keep peak memory down.
|
||||
controlnet_conds: list[torch.Tensor] = []
|
||||
for controlnet in controlnets:
|
||||
for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
|
||||
image = context.images.get_pil(controlnet.image.image_name)
|
||||
|
||||
# HACK(ryand): We have to load the ControlNet model to determine whether the VAE needs to be run. We really
|
||||
# shouldn't have to load the model here. There's a risk that the model will be dropped from the model cache
|
||||
# before we load it into VRAM and thus we'll have to load it again (context:
|
||||
# https://github.com/invoke-ai/InvokeAI/issues/7513).
|
||||
controlnet_model = context.models.load(controlnet.control_model)
|
||||
if isinstance(controlnet_model.model, InstantXControlNetFlux):
|
||||
if self.controlnet_vae is None:
|
||||
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
|
||||
@@ -555,8 +550,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
# Finally, load the ControlNet models and initialize the ControlNet extensions.
|
||||
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
|
||||
for controlnet, controlnet_cond in zip(controlnets, controlnet_conds, strict=True):
|
||||
model = exit_stack.enter_context(context.models.load(controlnet.control_model))
|
||||
for controlnet, controlnet_cond, controlnet_model in zip(
|
||||
controlnets, controlnet_conds, controlnet_models, strict=True
|
||||
):
|
||||
model = exit_stack.enter_context(controlnet_model)
|
||||
|
||||
if isinstance(model, XLabsControlNetFlux):
|
||||
controlnet_extensions.append(
|
||||
@@ -626,7 +623,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
self,
|
||||
ip_adapter_fields: list[IPAdapterField],
|
||||
context: InvocationContext,
|
||||
device: torch.device,
|
||||
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
|
||||
clip_image_processor = CLIPImageProcessor()
|
||||
@@ -666,11 +662,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
|
||||
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
|
||||
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
|
||||
pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds
|
||||
|
||||
clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
|
||||
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
|
||||
neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds
|
||||
|
||||
pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)
|
||||
|
||||
@@ -10,10 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
@@ -78,8 +74,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
|
||||
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
@@ -2,7 +2,7 @@ from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
@@ -69,14 +69,17 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
|
||||
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
|
||||
t5_text_encoder_info as t5_text_encoder,
|
||||
t5_tokenizer_info as t5_tokenizer,
|
||||
):
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
|
||||
assert isinstance(t5_tokenizer, T5Tokenizer)
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
@@ -87,20 +90,22 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
prompt = [self.prompt]
|
||||
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
clip_text_encoder_config = clip_text_encoder_info.config
|
||||
assert clip_text_encoder_config is not None
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
|
||||
context.models.load(self.clip.tokenizer) as clip_tokenizer,
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
ExitStack() as exit_stack,
|
||||
):
|
||||
assert isinstance(clip_text_encoder, CLIPTextModel)
|
||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||
|
||||
clip_text_encoder_config = clip_text_encoder_info.config
|
||||
assert clip_text_encoder_config is not None
|
||||
|
||||
# Apply LoRA models to the CLIP encoder.
|
||||
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
|
||||
@@ -3,7 +3,6 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -25,7 +24,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i", "flux"],
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@@ -39,23 +38,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision).
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1090 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
# We add a 20% buffer to the working memory estimate to be safe.
|
||||
working_memory = working_memory * 1.2
|
||||
return int(working_memory)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
|
||||
@@ -23,7 +23,6 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
@@ -162,12 +161,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
crop: bool = InputField(default=False, description="Crop to base image dimensions")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.images.get_pil(self.base_image.image_name, mode="RGBA")
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
base_image = context.images.get_pil(self.base_image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
mask = None
|
||||
if self.mask is not None:
|
||||
mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
mask = ImageOps.invert(mask)
|
||||
mask = context.images.get_pil(self.mask.image_name)
|
||||
mask = ImageOps.invert(mask.convert("L"))
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
|
||||
min_x = min(0, self.x)
|
||||
@@ -177,11 +176,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
|
||||
# Create a temporary image to paste the image with transparency
|
||||
temp_image = Image.new("RGBA", new_image.size)
|
||||
temp_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
new_image = Image.alpha_composite(new_image, temp_image)
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
|
||||
if self.crop:
|
||||
base_w, base_h = base_image.size
|
||||
@@ -306,44 +301,14 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Split the image into RGBA channels
|
||||
r, g, b, a = image.split()
|
||||
|
||||
# Premultiply RGB channels by alpha
|
||||
premultiplied_image = ImageChops.multiply(image, a.convert("RGBA"))
|
||||
premultiplied_image.putalpha(a)
|
||||
|
||||
# Apply the blur
|
||||
blur = (
|
||||
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
|
||||
)
|
||||
blurred_image = premultiplied_image.filter(blur)
|
||||
blur_image = image.filter(blur)
|
||||
|
||||
# Split the blurred image into RGBA channels
|
||||
r, g, b, a_orig = blurred_image.split()
|
||||
|
||||
# Convert to float using NumPy. float 32/64 division are much faster than float 16
|
||||
r = numpy.array(r, dtype=numpy.float32)
|
||||
g = numpy.array(g, dtype=numpy.float32)
|
||||
b = numpy.array(b, dtype=numpy.float32)
|
||||
a = numpy.array(a_orig, dtype=numpy.float32) / 255.0 # Normalize alpha to [0, 1]
|
||||
|
||||
# Unpremultiply RGB channels by alpha
|
||||
r /= a + 1e-6 # Add a small epsilon to avoid division by zero
|
||||
g /= a + 1e-6
|
||||
b /= a + 1e-6
|
||||
|
||||
# Convert back to PIL images
|
||||
r = Image.fromarray(numpy.uint8(numpy.clip(r, 0, 255)))
|
||||
g = Image.fromarray(numpy.uint8(numpy.clip(g, 0, 255)))
|
||||
b = Image.fromarray(numpy.uint8(numpy.clip(b, 0, 255)))
|
||||
|
||||
# Merge back into a single image
|
||||
result_image = Image.merge("RGBA", (r, g, b, a_orig))
|
||||
|
||||
image_dto = context.images.save(image=result_image)
|
||||
image_dto = context.images.save(image=blur_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -1090,67 +1055,3 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_dto = context.images.save(image=generated_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_noise",
|
||||
title="Add Image Noise",
|
||||
tags=["image", "noise"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Add noise to an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to add noise to")
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description=FieldDescriptions.seed,
|
||||
)
|
||||
noise_type: Literal["gaussian", "salt_and_pepper"] = InputField(
|
||||
default="gaussian",
|
||||
description="The type of noise to add",
|
||||
)
|
||||
amount: float = InputField(default=0.1, ge=0, le=1, description="The amount of noise to add")
|
||||
noise_color: bool = InputField(default=True, description="Whether to add colored noise")
|
||||
size: int = InputField(default=1, ge=1, description="The size of the noise points")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
|
||||
# Save out the alpha channel
|
||||
alpha = image.getchannel("A")
|
||||
|
||||
# Set the seed for numpy random
|
||||
rs = numpy.random.RandomState(numpy.random.MT19937(numpy.random.SeedSequence(self.seed)))
|
||||
|
||||
if self.noise_type == "gaussian":
|
||||
if self.noise_color:
|
||||
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size, 3)) * 255
|
||||
else:
|
||||
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size)) * 255
|
||||
noise = numpy.stack([noise] * 3, axis=-1)
|
||||
elif self.noise_type == "salt_and_pepper":
|
||||
if self.noise_color:
|
||||
noise = rs.choice(
|
||||
[0, 255], (image.height // self.size, image.width // self.size, 3), p=[1 - self.amount, self.amount]
|
||||
)
|
||||
else:
|
||||
noise = rs.choice(
|
||||
[0, 255], (image.height // self.size, image.width // self.size), p=[1 - self.amount, self.amount]
|
||||
)
|
||||
noise = numpy.stack([noise] * 3, axis=-1)
|
||||
|
||||
noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
|
||||
(image.width, image.height), Image.Resampling.NEAREST
|
||||
)
|
||||
noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")
|
||||
|
||||
# Paste back the alpha channel
|
||||
noisy_image.putalpha(alpha)
|
||||
|
||||
image_dto = context.images.save(image=noisy_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -26,7 +26,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -99,7 +98,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode(), tiling_context:
|
||||
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i"],
|
||||
category="latents",
|
||||
version="1.3.1",
|
||||
version="1.3.0",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@@ -53,58 +53,16 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
def _estimate_working_memory(
|
||||
self, latents: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
|
||||
) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision). This estimate is accurate for both SD1 and SDXL.
|
||||
element_size = 4 if self.fp32 else 2
|
||||
scaling_constant = 960 # Determined experimentally.
|
||||
|
||||
if use_tiling:
|
||||
tile_size = self.tile_size
|
||||
if tile_size == 0:
|
||||
tile_size = vae.tile_sample_min_size
|
||||
assert isinstance(tile_size, int)
|
||||
out_h = tile_size
|
||||
out_w = tile_size
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
# We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
|
||||
# and number of tiles. We could make this more precise in the future, but this should be good enough for
|
||||
# most use cases.
|
||||
working_memory = working_memory * 1.25
|
||||
else:
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
if self.fp32:
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
working_memory += 250 * 2**20
|
||||
|
||||
# We add 20% to the working memory estimate to be safe.
|
||||
working_memory = int(working_memory * 1.2)
|
||||
return working_memory
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
use_tiling = self.tiled or context.config.get().force_tiled_decode
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
|
||||
estimated_working_memory = self._estimate_working_memory(latents, use_tiling, vae_info.model)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
):
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
context.util.signal_progress("Running VAE decoder")
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(TorchDevice.choose_torch_device())
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
|
||||
@@ -130,7 +88,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if use_tiling:
|
||||
if self.tiled or context.config.get().force_tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -538,3 +539,23 @@ class BoundingBoxInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "internal"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageBatchInvocation(BaseInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
@@ -16,7 +16,6 @@ from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -40,7 +39,7 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
vae.disable_tiling()
|
||||
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
# TODO: Use seed to make sampling reproducible.
|
||||
|
||||
@@ -6,7 +6,6 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -27,7 +26,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="SD3 Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i", "sd3"],
|
||||
category="latents",
|
||||
version="1.3.1",
|
||||
version="1.3.0",
|
||||
)
|
||||
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@@ -41,34 +40,16 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
        input=Input.Connection,
    )

    def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
        """Estimate the working memory required by the invocation in bytes."""
        # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
        # element size (precision).
        out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
        out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
        element_size = next(vae.parameters()).element_size()
        scaling_constant = 1230  # Determined experimentally.
        working_memory = out_h * out_w * element_size * scaling_constant

        # We add a 20% buffer to the working memory estimate to be safe.
        working_memory = working_memory * 1.2
        return int(working_memory)
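As a rough worked example (assuming the usual 8x latent scale factor and a 1024x1024 fp16 output): 1024 x 1024 x 2 bytes x 1230 ≈ 2.4 GiB, and the 20% buffer brings the reservation to roughly 2.9 GiB. Note the larger experimentally determined constant here (1230 vs. 960) compared with the SD1/SDXL decoder above.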
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL))
|
||||
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
|
||||
with (
|
||||
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
|
||||
vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae),
|
||||
):
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
context.util.signal_progress("Running VAE")
|
||||
assert isinstance(vae, (AutoencoderKL))
|
||||
latents = latents.to(TorchDevice.choose_torch_device())
|
||||
latents = latents.to(vae.device)
|
||||
|
||||
vae.disable_tiling()
|
||||
|
||||
|
||||
@@ -10,10 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
|
||||
@@ -92,8 +88,16 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
|
||||
if self.clip_g_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
)
|
||||
tokenizer_t5 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model or self.model)
|
||||
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model or self.model)
|
||||
tokenizer_t5 = (
|
||||
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
if self.t5_encoder_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
)
|
||||
t5_encoder = (
|
||||
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
if self.t5_encoder_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
)
|
||||
|
||||
return Sd3ModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
|
||||
@@ -21,7 +21,6 @@ from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
|
||||
SD3_T5_MAX_SEQ_LEN = 256
|
||||
@@ -87,11 +86,14 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
|
||||
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
|
||||
assert self.t5_encoder is not None
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
|
||||
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
|
||||
t5_text_encoder_info as t5_text_encoder,
|
||||
t5_tokenizer_info as t5_tokenizer,
|
||||
):
|
||||
context.util.signal_progress("Running T5 encoder")
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
@@ -118,7 +120,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
f" {max_seq_len} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0]
|
||||
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
@@ -126,12 +128,14 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
def _clip_encode(
|
||||
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
|
||||
with (
|
||||
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
|
||||
context.models.load(clip_model.tokenizer) as clip_tokenizer,
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
ExitStack() as exit_stack,
|
||||
):
|
||||
context.util.signal_progress("Running CLIP encoder")
|
||||
@@ -181,7 +185,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
f" {tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = clip_text_encoder(
|
||||
input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
|
||||
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
|
||||
)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
@@ -22,7 +22,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
|
||||
@@ -103,7 +102,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
|
||||
)
|
||||
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=spandrel_model.dtype)
|
||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run the model on each tile.
|
||||
pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
|
||||
@@ -117,7 +116,9 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
raise CanceledException
|
||||
|
||||
# Extract the current tile from the input tensor.
|
||||
input_tile = image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
|
||||
input_tile = image_tensor[
|
||||
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
|
||||
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run the model on the tile.
|
||||
output_tile = spandrel_model.run(input_tile)
|
||||
@@ -150,12 +151,15 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return pil_image
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
def step_callback(step: int, total_steps: int) -> None:
|
||||
context.util.signal_progress(
|
||||
message=f"Processing tile {step}/{total_steps}",
|
||||
@@ -163,7 +167,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
)
|
||||
|
||||
# Do the upscaling.
|
||||
with context.models.load(self.image_to_image_model) as spandrel_model:
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# Upscale the image
|
||||
@@ -196,12 +200,15 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
|
||||
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
|
||||
target_width = int(image.width * self.scale)
|
||||
@@ -214,7 +221,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
)
|
||||
|
||||
# Do the upscaling.
|
||||
with context.models.load(self.image_to_image_model) as spandrel_model:
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
iteration = 1
|
||||
|
||||
@@ -201,24 +201,25 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
context.models.load(self.unet.unet) as unet,
|
||||
unet_info as unet,
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
|
||||
),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=unet.dtype)
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
unet_config=unet_config,
|
||||
)
|
||||
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
@@ -227,7 +228,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
device=device,
|
||||
device=unet.device,
|
||||
dtype=unet.dtype,
|
||||
latent_height=latent_tile_height,
|
||||
latent_width=latent_tile_width,
|
||||
@@ -240,7 +241,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=list(latents.shape),
|
||||
device=device,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
@@ -266,7 +266,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
||||
scheduler,
|
||||
device=device,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
|
||||
@@ -13,6 +13,7 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import psutil
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
|
||||
@@ -24,6 +25,8 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_RAM_CACHE = 10.0
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
|
||||
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
||||
@@ -33,6 +36,24 @@ LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.2"


def get_default_ram_cache_size() -> float:
    """Run a heuristic for the default RAM cache based on installed RAM."""

    # On some machines, psutil.virtual_memory().total gives a value that is slightly less than the actual RAM, so the
    # limits are set slightly lower than what we expect the actual RAM to be.

    GB = 1024**3
    max_ram = psutil.virtual_memory().total / GB

    if max_ram >= 60:
        return 15.0
    if max_ram >= 30:
        return 7.5
    if max_ram >= 14:
        return 4.0
    return 2.1  # 2.1 is just large enough for sd 1.5 ;-)
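As a worked example (the psutil figures are typical, not exact): a 64 GB machine usually reports a bit over 60 GiB and gets the 15 GB default; a 32 GB machine reports roughly 31.5 GiB and lands in the >= 30 bucket (7.5 GB); a 16 GB machine reports around 15.6 GiB and gets 4 GB; anything smaller falls through to 2.1 GB.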
||||
|
||||
|
||||
class URLRegexTokenPair(BaseModel):
|
||||
url_regex: str = Field(description="Regular expression to match against the URL")
|
||||
token: str = Field(description="Token to use when the URL matches the regex")
|
||||
@@ -82,14 +103,10 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
profile_graphs: Enable graph profiling using `cProfile`.
|
||||
profile_prefix: An optional prefix for profile output files.
|
||||
profiles_dir: Path to profiles output directory.
|
||||
max_cache_ram_gb: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
|
||||
max_cache_vram_gb: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
|
||||
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
|
||||
vram: Amount of VRAM reserved for model storage (GB).
|
||||
lazy_offload: Keep models in VRAM until their space is needed.
|
||||
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
||||
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
|
||||
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
|
||||
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
|
||||
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
|
||||
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
@@ -157,15 +174,10 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
|
||||
|
||||
# CACHE
|
||||
max_cache_ram_gb: Optional[float] = Field(default=None, gt=0, description="The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.")
|
||||
max_cache_vram_gb: Optional[float] = Field(default=None, ge=0, description="The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.")
|
||||
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
|
||||
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
|
||||
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
|
||||
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
|
||||
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
|
||||
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
|
||||
# Deprecated CACHE configs
|
||||
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
|
||||
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
|
||||
lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.")
|
||||
|
||||
# DEVICE
|
||||
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
|
||||
|
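For orientation, the non-deprecated cache fields above correspond to `invokeai.yaml` settings. A hedged sketch of how they might be set together (the numbers are illustrative, not recommendations):

```yaml
# Illustrative values only - in most cases the cache limits are best left unset.
enable_partial_loading: true
device_working_mem_gb: 4        # working memory kept free on the compute device, in GB
# max_cache_ram_gb: 20          # optional override of the automatic RAM cache limit
# max_cache_vram_gb: 10         # optional override of the automatic VRAM cache limit
```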
||||
@@ -82,12 +82,11 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
logger.setLevel(app_config.log_level.upper())
|
||||
|
||||
ram_cache = ModelCache(
|
||||
execution_device_working_mem_gb=app_config.device_working_mem_gb,
|
||||
enable_partial_loading=app_config.enable_partial_loading,
|
||||
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
|
||||
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
|
||||
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||
max_cache_size=app_config.ram,
|
||||
max_vram_cache_size=app_config.vram,
|
||||
lazy_offloading=app_config.lazy_offload,
|
||||
logger=logger,
|
||||
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=app_config,
|
||||
|
||||
@@ -108,16 +108,8 @@ class Batch(BaseModel):
            return v
        for batch_data_list in v:
            for datum in batch_data_list:
                if not datum.items:
                    continue

                # Special handling for numbers - they can be mixed
                # TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
                if all(isinstance(item, (int, float)) for item in datum.items):
                    continue

                # Get the type of the first item in the list
                first_item_type = type(datum.items[0])
                first_item_type = type(datum.items[0]) if datum.items else None
                for item in datum.items:
                    if type(item) is not first_item_type:
                        raise BatchItemsTypeError("All items in a batch must have the same type")
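In effect (a hypothetical illustration, not part of the change): a datum whose `items` are `[1, 2, 3.5]` passes validation, because ints and floats may be mixed, while `items` of `["a", 1]` raises `BatchItemsTypeError` since the item types differ (empty `items` lists are skipped earlier in the loop).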
||||
|
||||
@@ -1,26 +0,0 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.backend.model_manager.config import BaseModelType, SubModelType


def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
    """A helper function to normalize a T5 encoder model identifier so that T5 models associated with FLUX
    or SD3 models can be used interchangeably.
    """
    if model_identifier.base == BaseModelType.Any:
        return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
    elif model_identifier.base == BaseModelType.StableDiffusion3:
        return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
    else:
        raise ValueError(f"Unsupported model base: {model_identifier.base}")


def preprocess_t5_tokenizer_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
    """A helper function to normalize a T5 tokenizer model identifier so that T5 models associated with FLUX
    or SD3 models can be used interchangeably.
    """
    if model_identifier.base == BaseModelType.Any:
        return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
    elif model_identifier.base == BaseModelType.StableDiffusion3:
        return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
    else:
        raise ValueError(f"Unsupported model base: {model_identifier.base}")
||||
@@ -8,7 +8,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
class XLabsIPAdapterExtension:
|
||||
@@ -46,7 +45,7 @@ class XLabsIPAdapterExtension:
|
||||
) -> torch.Tensor:
|
||||
clip_image_processor = CLIPImageProcessor()
|
||||
clip_image: torch.Tensor = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(device=TorchDevice.choose_torch_device(), dtype=image_encoder.dtype)
|
||||
clip_image = clip_image.to(device=image_encoder.device, dtype=image_encoder.dtype)
|
||||
clip_image_embeds = image_encoder(clip_image).image_embeds
|
||||
return clip_image_embeds
|
||||
|
||||
|
||||
@@ -1,19 +1,11 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from torch import Tensor, nn
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
class HFEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
is_clip: bool,
|
||||
max_length: int,
|
||||
):
|
||||
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
self.is_clip = is_clip
|
||||
@@ -34,7 +26,7 @@ class HFEncoder(nn.Module):
|
||||
)
|
||||
|
||||
outputs = self.hf_module(
|
||||
input_ids=batch_encoding["input_ids"].to(TorchDevice.choose_torch_device()),
|
||||
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
||||
attention_mask=None,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,6 @@ from invokeai.backend.image_util.util import (
|
||||
resize_image_to_resolution,
|
||||
safe_step,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class DoubleConvBlock(torch.nn.Module):
|
||||
@@ -110,7 +109,7 @@ class HEDProcessor:
|
||||
Returns:
|
||||
The detected edges.
|
||||
"""
|
||||
device = get_effective_device(self.network)
|
||||
device = next(iter(self.network.parameters())).device
|
||||
np_image = pil_to_np(input_image)
|
||||
np_image = normalize_image_channel_count(np_image)
|
||||
np_image = resize_image_to_resolution(np_image, detect_resolution)
|
||||
@@ -184,7 +183,7 @@ class HEDEdgeDetector:
|
||||
The detected edges.
|
||||
"""
|
||||
|
||||
device = get_effective_device(self.model)
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
@@ -32,7 +31,7 @@ class LaMA:
|
||||
mask = norm_img(mask)
|
||||
mask = (mask > 0) * 1
|
||||
|
||||
device = get_effective_device(self._model)
|
||||
device = next(self._model.buffers()).device
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ from invokeai.backend.image_util.util import (
|
||||
pil_to_np,
|
||||
resize_image_to_resolution,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
@@ -131,7 +130,7 @@ class LineartProcessor:
|
||||
Returns:
|
||||
The detected lineart.
|
||||
"""
|
||||
device = get_effective_device(self.model)
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_image = pil_to_np(input_image)
|
||||
np_image = normalize_image_channel_count(np_image)
|
||||
@@ -202,7 +201,7 @@ class LineartEdgeDetector:
|
||||
Returns:
|
||||
The detected edges.
|
||||
"""
|
||||
device = get_effective_device(self.model)
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from invokeai.backend.image_util.util import (
|
||||
pil_to_np,
|
||||
resize_image_to_resolution,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class UnetGenerator(nn.Module):
|
||||
@@ -172,7 +171,7 @@ class LineartAnimeProcessor:
|
||||
Returns:
|
||||
The detected lineart.
|
||||
"""
|
||||
device = get_effective_device(self.model)
|
||||
device = next(iter(self.model.parameters())).device
|
||||
np_image = pil_to_np(input_image)
|
||||
|
||||
np_image = normalize_image_channel_count(np_image)
|
||||
@@ -240,7 +239,7 @@ class LineartAnimeEdgeDetector:
|
||||
|
||||
def run(self, image: Image.Image) -> Image.Image:
|
||||
"""Processes an image and returns the detected edges."""
|
||||
device = get_effective_device(self.model)
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
|
||||
@@ -14,8 +14,6 @@ import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
|
||||
'''
|
||||
@@ -51,7 +49,7 @@ def pred_lines(image, model,
|
||||
dist_thr=20.0):
|
||||
h, w, _ = image.shape
|
||||
|
||||
device = get_effective_device(model)
|
||||
device = next(iter(model.parameters())).device
|
||||
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
|
||||
|
||||
resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
|
||||
@@ -110,7 +108,7 @@ def pred_squares(image,
|
||||
'''
|
||||
h, w, _ = image.shape
|
||||
original_shape = [h, w]
|
||||
device = get_effective_device(model)
|
||||
device = next(iter(model.parameters())).device
|
||||
|
||||
resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
|
||||
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
||||
|
||||
@@ -13,7 +13,6 @@ from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class NormalMapDetector:
|
||||
@@ -65,7 +64,7 @@ class NormalMapDetector:
|
||||
def run(self, image: Image.Image):
|
||||
"""Processes an image and returns the detected normal map."""
|
||||
|
||||
device = get_effective_device(self.model)
|
||||
device = next(iter(self.model.parameters())).device
|
||||
np_image = pil_to_np(image)
|
||||
|
||||
height, width, _channels = np_image.shape
|
||||
|
||||
@@ -11,7 +11,6 @@ from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
|
||||
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
|
||||
|
||||
class PIDINetDetector:
|
||||
@@ -46,7 +45,7 @@ class PIDINetDetector:
|
||||
) -> Image.Image:
|
||||
"""Processes an image and returns the detected edges."""
|
||||
|
||||
device = get_effective_device(self.model)
|
||||
device = next(iter(self.model.parameters())).device
|
||||
|
||||
np_img = pil_to_np(image)
|
||||
np_img = normalize_image_channel_count(np_img)
|
||||
|
||||
@@ -57,31 +57,25 @@ class LoadedModelWithoutConfig:
|
||||
self._cache = cache
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
self._cache.lock(self._cache_record, None)
|
||||
self._cache.lock(self._cache_record)
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._cache.unlock(self._cache_record)
|
||||
|
||||
@contextmanager
|
||||
def model_on_device(
|
||||
self, working_mem_bytes: Optional[int] = None
|
||||
) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
|
||||
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device.
|
||||
|
||||
:param working_mem_bytes: The amount of working memory to keep available on the compute device when loading the
|
||||
model.
|
||||
"""
|
||||
self._cache.lock(self._cache_record, working_mem_bytes)
|
||||
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
|
||||
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
|
||||
self._cache.lock(self._cache_record)
|
||||
try:
|
||||
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
|
||||
yield (self._cache_record.state_dict, self._cache_record.model)
|
||||
finally:
|
||||
self._cache.unlock(self._cache_record)
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without locking it."""
|
||||
return self._cache_record.cached_model.model
|
||||
return self._cache_record.model
|
||||
|
||||
|
||||
class LoadedModel(LoadedModelWithoutConfig):
|
||||
|
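The two locking styles that appear in the invocation hunks above (`vae_info as vae` versus `vae_info.model_on_device(...) as (_, vae)`) map onto the entry points shown here. A minimal sketch of the pattern, assuming `loaded_model` stands for whatever `context.models.load(...)` returned; the function name and byte value are illustrative:

```python
from typing import Any

def run_with_working_memory(loaded_model: Any, working_mem_bytes: int = 2 * 2**30) -> None:
    """Sketch only: illustrates the two context-manager entry points above."""
    # model_on_device() also yields the read-only CPU state-dict copy; on one side of this
    # diff it accepts a working-memory hint that is honoured while loading the model.
    with loaded_model.model_on_device(working_mem_bytes=working_mem_bytes) as (_, model):
        ...  # run inference here while the model is locked on the execution device

    # The plain context-manager form locks the model without the working-memory hint.
    with loaded_model as model:
        ...  # run inference here
```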
||||
@@ -1,21 +1,38 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheRecord:
|
||||
"""A class that represents a model in the model cache."""
|
||||
"""
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
|
||||
The state_dict should be treated as a read-only attribute. Do not attempt
|
||||
to patch or otherwise modify it. Instead, patch the copy of the state_dict
|
||||
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
|
||||
context manager call `model_on_device()`.
|
||||
"""
|
||||
|
||||
# Cache key.
|
||||
key: str
|
||||
# Model in memory.
|
||||
cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
|
||||
model: Any
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
@@ -28,6 +45,6 @@ class CacheRecord:
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def is_locked(self) -> bool:
|
||||
def locked(self) -> bool:
|
||||
"""Return true if record is locked."""
|
||||
return self._locks > 0
|
||||
|
||||
@@ -7,6 +7,18 @@ from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
def set_nested_attr(obj: object, attr: str, value: object):
|
||||
"""A helper function that extends setattr() to support nested attributes.
|
||||
|
||||
Example:
|
||||
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
|
||||
"""
|
||||
attrs = attr.split(".")
|
||||
for attr in attrs[:-1]:
|
||||
obj = getattr(obj, attr)
|
||||
setattr(obj, attrs[-1], value)
|
||||
|
||||
|
||||
class CachedModelWithPartialLoad:
|
||||
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
|
||||
|
||||
@@ -21,14 +33,9 @@ class CachedModelWithPartialLoad:
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
|
||||
|
||||
# A dictionary of the size of each tensor in the state dict.
|
||||
# HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
|
||||
# consistency in case the application code has modified the model's size (e.g. by casting to a different
|
||||
# precision). Of course, this means that we are making model cache load/unload decisions based on model size
|
||||
# data that may not be fully accurate.
|
||||
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
|
||||
|
||||
self._total_bytes = sum(self._state_dict_bytes.values())
|
||||
# TODO(ryand): Handle the case where the model sizes changes after initial load (e.g. due to dtype casting).
|
||||
# Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
|
||||
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
|
||||
self._cur_vram_bytes: int | None = None
|
||||
|
||||
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
|
||||
@@ -84,9 +91,7 @@ class CachedModelWithPartialLoad:
|
||||
if self._cur_vram_bytes is None:
|
||||
cur_state_dict = self._model.state_dict()
|
||||
self._cur_vram_bytes = sum(
|
||||
self._state_dict_bytes[k]
|
||||
for k, v in cur_state_dict.items()
|
||||
if v.device.type == self._compute_device.type
|
||||
calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
|
||||
)
|
||||
return self._cur_vram_bytes
|
||||
|
||||
@@ -118,7 +123,7 @@ class CachedModelWithPartialLoad:
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
param_size = self._state_dict_bytes[key]
|
||||
param_size = calc_tensor_size(param)
|
||||
cur_state_dict[key] = param.to(self._compute_device, copy=True)
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
@@ -135,7 +140,7 @@ class CachedModelWithPartialLoad:
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
param_size = self._state_dict_bytes[key]
|
||||
param_size = calc_tensor_size(param)
|
||||
if vram_bytes_loaded + param_size > vram_bytes_to_load:
|
||||
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
|
||||
# worth continuing to search for a smaller parameter that would fit?
|
||||
@@ -156,6 +161,7 @@ class CachedModelWithPartialLoad:
|
||||
|
||||
if fully_loaded:
|
||||
self._set_autocast_enabled_in_all_modules(False)
|
||||
# TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
|
||||
else:
|
||||
self._set_autocast_enabled_in_all_modules(True)
|
||||
|
||||
@@ -166,17 +172,13 @@ class CachedModelWithPartialLoad:
|
||||
return vram_bytes_loaded
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weights_in_vram: bool = False) -> int:
|
||||
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
|
||||
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
|
||||
|
||||
:param keep_required_weights_in_vram: If True, any weights that must be kept in VRAM to run the model will be
|
||||
kept in VRAM.
|
||||
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
vram_bytes_freed = 0
|
||||
required_weights_in_vram = 0
|
||||
|
||||
offload_device = "cpu"
|
||||
cur_state_dict = self._model.state_dict()
|
||||
@@ -187,12 +189,8 @@ class CachedModelWithPartialLoad:
|
||||
if param.device.type == offload_device:
|
||||
continue
|
||||
|
||||
if keep_required_weights_in_vram and key in self._keys_in_modules_that_do_not_support_autocast:
|
||||
required_weights_in_vram += self._state_dict_bytes[key]
|
||||
continue
|
||||
|
||||
cur_state_dict[key] = self._cpu_state_dict[key]
|
||||
vram_bytes_freed += self._state_dict_bytes[key]
|
||||
vram_bytes_freed += calc_tensor_size(param)
|
||||
|
||||
if vram_bytes_freed > 0:
|
||||
self._model.load_state_dict(cur_state_dict, assign=True)
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_operation_vram_usage(operation_name: str):
|
||||
"""A helper function for tuning working memory requirements for memory-intensive ops.
|
||||
|
||||
Sample usage:
|
||||
|
||||
```python
|
||||
with log_operation_vram_usage("some_operation"):
|
||||
some_operation()
|
||||
```
|
||||
"""
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
max_allocated_before = torch.cuda.max_memory_allocated()
|
||||
max_reserved_before = torch.cuda.max_memory_reserved()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.cuda.synchronize()
|
||||
max_allocated_after = torch.cuda.max_memory_allocated()
|
||||
max_reserved_after = torch.cuda.max_memory_reserved()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
logger.info(
|
||||
f">>>{operation_name} Peak VRAM allocated: {(max_allocated_after - max_allocated_before) / 2**20} MB, "
|
||||
f"Peak VRAM reserved: {(max_reserved_after - max_reserved_before) / 2**20} MB"
|
||||
)
|
||||
@@ -1,29 +1,24 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
apply_custom_layers_to_model,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter
|
||||
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
@@ -34,7 +29,6 @@ MB = 2**20
|
||||
|
||||
# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
|
||||
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
|
||||
"""Get the cache key for a model based on the optional submodel type."""
|
||||
if submodel_type:
|
||||
return f"{model_key}:{submodel_type.value}"
|
||||
else:
|
||||
@@ -76,51 +70,61 @@ class ModelCache:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
execution_device_working_mem_gb: float,
|
||||
enable_partial_loading: bool,
|
||||
max_ram_cache_size_gb: float | None = None,
|
||||
max_vram_cache_size_gb: float | None = None,
|
||||
execution_device: torch.device | str = "cuda",
|
||||
storage_device: torch.device | str = "cpu",
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
lazy_offloading: bool = True,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""Initialize the model RAM cache.
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param execution_device_working_mem_gb: The amount of working memory to keep on the GPU (in GB) i.e. non-model
|
||||
VRAM.
|
||||
:param enable_partial_loading: Whether to enable partial loading of models.
|
||||
:param max_ram_cache_size_gb: The maximum amount of CPU RAM to use for model caching in GB. This parameter is
|
||||
kept to maintain compatibility with previous versions of the model cache, but should be deprecated in the
|
||||
future. If set, this parameter overrides the default cache size logic.
|
||||
:param max_vram_cache_size_gb: The amount of VRAM to use for model caching in GB. This parameter is kept to
|
||||
maintain compatibility with previous versions of the model cache, but should be deprecated in the future.
|
||||
If set, this parameter overrides the default cache size logic.
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
self._enable_partial_loading = enable_partial_loading
|
||||
self._execution_device_working_mem_gb = execution_device_working_mem_gb
|
||||
self._execution_device: torch.device = torch.device(execution_device)
|
||||
self._storage_device: torch.device = torch.device(storage_device)
|
||||
|
||||
self._max_ram_cache_size_gb = max_ram_cache_size_gb
|
||||
self._max_vram_cache_size_gb = max_vram_cache_size_gb
|
||||
|
||||
self._logger = PrefixedLoggerAdapter(
|
||||
logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE"
|
||||
)
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
@@ -128,17 +132,17 @@ class ModelCache:
|
||||
|
||||
@stats.setter
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collecting cache statistics."""
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
self._stats = stats
|
||||
|
||||
def put(self, key: str, model: AnyModel) -> None:
|
||||
"""Add a model to the cache."""
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: AnyModel,
|
||||
) -> None:
|
||||
"""Insert model into the cache."""
|
||||
if key in self._cached_models:
|
||||
self._logger.debug(
|
||||
f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary."
|
||||
)
|
||||
return
|
||||
|
||||
size = calc_model_size_by_data(self._logger, model)
|
||||
self.make_room(size)
|
||||
|
||||
@@ -146,26 +150,17 @@ class ModelCache:
|
||||
if isinstance(model, torch.nn.Module):
|
||||
apply_custom_layers_to_model(model)
|
||||
|
||||
# Partial loading only makes sense on CUDA.
|
||||
# - When running on CPU, there is no 'loading' to do.
|
||||
# - When running on MPS, memory is shared with the CPU, so the default OS memory management already handles this
|
||||
# well.
|
||||
running_with_cuda = self._execution_device.type == "cuda"
|
||||
|
||||
# Wrap model.
|
||||
if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
|
||||
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
|
||||
else:
|
||||
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
|
||||
|
||||
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
|
||||
running_on_cpu = self._execution_device == torch.device("cpu")
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
self._logger.debug(
|
||||
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
|
||||
)
|
||||
|
||||
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> CacheRecord:
|
||||
"""Retrieve a model from the cache.
|
||||
|
||||
:param key: Model key
|
||||
@@ -179,7 +174,6 @@ class ModelCache:
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
self._logger.debug(f"Cache miss: {key}")
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
@@ -187,44 +181,37 @@ class ModelCache:
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
|
||||
# This moves the entry to the top (right end) of the stack.
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
self._cache_stack = [k for k in self._cache_stack if k != key]
|
||||
self._cache_stack.append(key)
|
||||
|
||||
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
return cache_entry
|
||||
|
||||
def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
|
||||
def lock(self, cache_entry: CacheRecord) -> None:
|
||||
"""Lock a model for use and move it into VRAM."""
|
||||
if cache_entry.key not in self._cached_models:
|
||||
self._logger.info(
|
||||
f"Locking model cache entry {cache_entry.key} "
|
||||
f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from "
|
||||
"the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code "
|
||||
"(See https://github.com/invoke-ai/InvokeAI/issues/7513)."
|
||||
f"Locking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
|
||||
"already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
|
||||
"in the invocation code."
|
||||
)
|
||||
# cache_entry = self._cached_models[key]
|
||||
cache_entry.lock()
|
||||
|
||||
self._logger.debug(
|
||||
f"Locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
|
||||
)
|
||||
|
||||
if self._execution_device.type == "cpu":
|
||||
# Models don't need to be loaded into VRAM if we're running on CPU.
|
||||
return
|
||||
|
||||
try:
|
||||
self._load_locked_model(cache_entry, working_mem_bytes)
|
||||
self._logger.debug(
|
||||
f"Finished locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
|
||||
)
|
||||
if self._lazy_offloading:
|
||||
self._offload_unlocked_models(cache_entry.size)
|
||||
self._move_model_to_device(cache_entry, self._execution_device)
|
||||
cache_entry.loaded = True
|
||||
self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}")
|
||||
self._print_cuda_stats()
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
cache_entry.unlock()
|
||||
@@ -233,333 +220,201 @@ class ModelCache:
|
||||
cache_entry.unlock()
|
||||
raise
|
||||
|
||||
self._log_cache_state()
|
||||
|
||||
def unlock(self, cache_entry: CacheRecord) -> None:
|
||||
"""Unlock a model."""
|
||||
if cache_entry.key not in self._cached_models:
|
||||
self._logger.info(
|
||||
f"Unlocking model cache entry {cache_entry.key} "
|
||||
f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from "
|
||||
"the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code "
|
||||
"(See https://github.com/invoke-ai/InvokeAI/issues/7513)."
|
||||
f"Unlocking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
|
||||
"already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
|
||||
"in the invocation code."
|
||||
)
|
||||
# cache_entry = self._cached_models[key]
|
||||
cache_entry.unlock()
|
||||
self._logger.debug(
|
||||
f"Unlocked model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
|
||||
)
|
||||
if not self._lazy_offloading:
|
||||
self._offload_unlocked_models(0)
|
||||
self._print_cuda_stats()
|
||||
|
||||
def _load_locked_model(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int] = None) -> None:
|
||||
"""Helper function for self.lock(). Loads a locked model into VRAM."""
|
||||
start_time = time.time()
|
||||
|
||||
# Calculate model_vram_needed, the amount of additional VRAM that will be used if we fully load the model into
|
||||
# VRAM.
|
||||
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
model_total_bytes = cache_entry.cached_model.total_bytes()
|
||||
model_vram_needed = model_total_bytes - model_cur_vram_bytes
|
||||
|
||||
vram_available = self._get_vram_available(working_mem_bytes)
|
||||
self._logger.debug(
|
||||
f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
# Make room for the model in VRAM.
|
||||
# 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
|
||||
# 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
|
||||
# possible.
|
||||
vram_bytes_freed = self._offload_unlocked_models(model_vram_needed, working_mem_bytes)
|
||||
self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
|
||||
|
||||
# Check the updated vram_available after offloading.
|
||||
vram_available = self._get_vram_available(working_mem_bytes)
|
||||
self._logger.debug(
|
||||
f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
if vram_available < 0:
|
||||
# There is insufficient VRAM available. As a last resort, try to unload the model being locked from VRAM,
|
||||
# as it may still be loaded from a previous use.
|
||||
vram_bytes_freed_from_own_model = self._move_model_to_ram(cache_entry, -vram_available)
|
||||
vram_available = self._get_vram_available(working_mem_bytes)
|
||||
self._logger.debug(
|
||||
f"Unloaded {vram_bytes_freed_from_own_model/MB:.2f}MB from the model being locked ({cache_entry.key})."
|
||||
)
|
||||
|
||||
# Move as much of the model as possible into VRAM.
|
||||
# For testing, only allow 10% of the model to be loaded into VRAM.
|
||||
# vram_available = int(model_vram_needed * 0.1)
|
||||
# We add 1 MB to the available VRAM to account for small errors in memory tracking (e.g. off-by-one). A fully
|
||||
# loaded model is much faster than a 95% loaded model.
|
||||
model_bytes_loaded = self._move_model_to_vram(cache_entry, vram_available + MB)
|
||||
|
||||
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
vram_available = self._get_vram_available(working_mem_bytes)
|
||||
loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
|
||||
self._logger.info(
|
||||
f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto "
|
||||
f"{self._execution_device.type} device in {(time.time() - start_time):.2f}s. "
|
||||
f"Total model size: {model_total_bytes/MB:.2f}MB, "
|
||||
f"VRAM: {model_cur_vram_bytes/MB:.2f}MB ({loaded_percent:.1%})"
|
||||
)
|
||||
self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ")
|
||||
self._logger.debug(
|
||||
f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
def _move_model_to_vram(self, cache_entry: CacheRecord, vram_available: int) -> int:
|
||||
try:
|
||||
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
|
||||
return cache_entry.cached_model.partial_load_to_vram(vram_available)
|
||||
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
|
||||
# Partial load is not supported, so we have no choice but to try to fit it all into VRAM.
|
||||
return cache_entry.cached_model.full_load_to_vram()
|
||||
else:
|
||||
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
|
||||
except Exception as e:
|
||||
if isinstance(e, torch.cuda.OutOfMemoryError):
|
||||
self._logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
# If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely.
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise
|
||||
|
||||
def _move_model_to_ram(self, cache_entry: CacheRecord, vram_bytes_to_free: int) -> int:
|
||||
try:
|
||||
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
|
||||
return cache_entry.cached_model.partial_unload_from_vram(
|
||||
vram_bytes_to_free, keep_required_weights_in_vram=cache_entry.is_locked
|
||||
)
|
||||
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
|
||||
return cache_entry.cached_model.full_unload_from_vram()
|
||||
else:
|
||||
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
|
||||
except Exception:
|
||||
# If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely.
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise
|
||||
|
||||
def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
    """Calculate the amount of additional VRAM available for the cache to use (takes into account the working
    memory).
    """
    # If self._max_vram_cache_size_gb is set, then it overrides the default logic.
    if self._max_vram_cache_size_gb is not None:
        vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
        return vram_total_available_to_cache - self._get_vram_in_use()

    working_mem_bytes_default = int(self._execution_device_working_mem_gb * GB)
    working_mem_bytes = max(working_mem_bytes or working_mem_bytes_default, working_mem_bytes_default)

    if self._execution_device.type == "cuda":
        # TODO(ryand): It is debatable whether we should use memory_reserved() or memory_allocated() here.
        # memory_reserved() includes memory reserved by the torch CUDA memory allocator that may or may not be
        # re-used for future allocations. For now, we use memory_allocated() to be conservative.
        # vram_reserved = torch.cuda.memory_reserved(self._execution_device)
        vram_allocated = torch.cuda.memory_allocated(self._execution_device)
        vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
        vram_available_to_process = vram_free + vram_allocated
    elif self._execution_device.type == "mps":
        vram_reserved = torch.mps.driver_allocated_memory()
        # TODO(ryand): Is it accurate that MPS shares memory with the CPU?
        vram_free = psutil.virtual_memory().available
        vram_available_to_process = vram_free + vram_reserved
    else:
        raise ValueError(f"Unsupported execution device: {self._execution_device.type}")

    vram_total_available_to_cache = vram_available_to_process - working_mem_bytes
    vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
    return vram_cur_available_to_cache
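As a rough, self-contained sketch of the budgeting arithmetic above (the byte counts below are assumptions for illustration, not values taken from this change):

# Minimal sketch of the VRAM budget; all numbers are assumptions.
GB = 2**30

def vram_available_for_cache(vram_free: int, vram_allocated: int,
                             vram_in_use_by_cache: int, working_mem_bytes: int) -> int:
    # free + already-allocated approximates what this process could hold at most
    vram_available_to_process = vram_free + vram_allocated
    # reserve working memory for activations, then subtract what the cache already holds
    return vram_available_to_process - working_mem_bytes - vram_in_use_by_cache

# e.g. 6 GB free, 4 GB allocated, cache holds 4 GB, keep 3 GB working memory -> 3 GB of headroom
print(vram_available_for_cache(6 * GB, 4 * GB, 4 * GB, 3 * GB) / GB)  # 3.0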
|
||||
|
||||
def _get_vram_in_use(self) -> int:
    """Get the amount of VRAM currently in use by the cache."""
    if self._execution_device.type == "cuda":
        return torch.cuda.memory_allocated()
    elif self._execution_device.type == "mps":
        return torch.mps.current_allocated_memory()
    else:
        raise ValueError(f"Unsupported execution device type: {self._execution_device.type}")
    # Alternative definition of VRAM in use:
    # return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
|
||||
|
||||
def _get_ram_available(self) -> int:
    """Get the amount of RAM available for the cache to use, while keeping memory pressure under control."""
    # If self._max_ram_cache_size_gb is set, then it overrides the default logic.
    if self._max_ram_cache_size_gb is not None:
        ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
        return ram_total_available_to_cache - self._get_ram_in_use()

    virtual_memory = psutil.virtual_memory()
    ram_total = virtual_memory.total
    ram_available = virtual_memory.available
    ram_used = ram_total - ram_available

    # The total size of all the models in the cache will often be larger than the amount of RAM reported by psutil
    # (due to lazy-loading and OS RAM caching behaviour). We could just rely on the psutil values, but it feels
    # like a bad idea to over-fill the model cache. So, for now, we'll try to keep the total size of models in the
    # cache under the total amount of system RAM.
    cache_ram_used = self._get_ram_in_use()
    ram_used = max(cache_ram_used, ram_used)

    # Aim to keep 10% of RAM free.
    ram_available_based_on_memory_usage = int(ram_total * 0.9) - ram_used

    # If we are running out of RAM, then there's an increased likelihood that we will run into this issue:
    # https://github.com/invoke-ai/InvokeAI/issues/7513
    # To keep things running smoothly, there's a minimum RAM cache size that we always allow (even if this means
    # using swap).
    min_ram_cache_size_bytes = 4 * GB
    ram_available_based_on_min_cache_size = min_ram_cache_size_bytes - cache_ram_used

    return max(ram_available_based_on_memory_usage, ram_available_based_on_min_cache_size)
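A minimal worked example of the two bounds computed above (all figures are assumptions, not values from this change):

# Assumed figures: 32 GB system RAM, 20 GB reported as used, cache believes it holds 24 GB.
GB = 2**30
ram_total, ram_used_reported, cache_ram_used = 32 * GB, 20 * GB, 24 * GB

ram_used = max(cache_ram_used, ram_used_reported)             # trust the larger of the two figures
by_memory_usage = int(ram_total * 0.9) - ram_used             # keep roughly 10% of RAM free
by_min_cache_size = 4 * GB - cache_ram_used                   # but always allow a 4 GB cache floor
print(max(by_memory_usage, by_min_cache_size) / GB)           # ~4.8, i.e. about 4.8 GB still usable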
|
||||
|
||||
def _get_ram_in_use(self) -> int:
    """Get the amount of RAM currently in use."""
    return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
def _get_cache_size(self) -> int:
    """Get the total size of the models currently cached."""
    total = 0
    for cache_record in self._cached_models.values():
        total += cache_record.size
    return total
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str:
|
||||
"""Helper function for preparing a VRAM state log string."""
|
||||
model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
|
||||
return (
|
||||
f"model_total={model_total_bytes/MB:.0f} MB, "
|
||||
+ f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
|
||||
# + f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
|
||||
+ f"vram_available={(vram_available/MB):.0f} MB, "
|
||||
)
|
||||
def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
    if submodel_type:
        return f"{model_key}:{submodel_type.value}"
    else:
        return model_key
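For illustration, a standalone sketch of the key scheme (the hash and submodel name below are hypothetical):

# Hypothetical model hash and submodel name, for illustration only.
def make_cache_key(model_key: str, submodel_value: str | None = None) -> str:
    return f"{model_key}:{submodel_value}" if submodel_value else model_key

print(make_cache_key("abc123"))                  # abc123
print(make_cache_key("abc123", "text_encoder"))  # abc123:text_encoder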
|
||||
|
||||
def _offload_unlocked_models(self, vram_bytes_required: int, working_mem_bytes: Optional[int] = None) -> int:
|
||||
"""Offload models from the execution_device until vram_bytes_required bytes are available, or all models are
|
||||
offloaded. Of course, locked models are not offloaded.
|
||||
def _offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
Returns:
|
||||
int: The number of bytes freed based on believed model sizes. The actual change in VRAM may be different.
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
self._logger.debug(
|
||||
f"Offloading unlocked models with goal of making room for {vram_bytes_required/MB:.2f}MB of VRAM."
|
||||
)
|
||||
vram_bytes_freed = 0
|
||||
# TODO(ryand): Give more thought to the offloading policy used here.
|
||||
cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
|
||||
for cache_entry in cache_entries_increasing_size:
|
||||
# We do not fully trust the count of bytes freed, so we check again on each iteration.
|
||||
vram_available = self._get_vram_available(working_mem_bytes)
|
||||
vram_bytes_to_free = vram_bytes_required - vram_available
|
||||
if vram_bytes_to_free <= 0:
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if cache_entry.is_locked:
|
||||
# TODO(ryand): In the future, we may want to partially unload locked models, but this requires careful
|
||||
# handling of model patches (e.g. LoRA).
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free)
|
||||
if cache_entry_bytes_freed > 0:
|
||||
if not cache_entry.locked:
|
||||
self._move_model_to_device(cache_entry, self._storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self._logger.debug(
|
||||
f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB."
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
)
|
||||
vram_bytes_freed += cache_entry_bytes_freed
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
return vram_bytes_freed
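The eviction order above is smallest-model-first among unlocked entries; a minimal standalone sketch of that policy follows (entry names and sizes are assumptions):

# Minimal sketch of smallest-first offloading; entries and sizes are illustrative.
from dataclasses import dataclass

@dataclass
class _Entry:
    key: str
    total_bytes: int
    is_locked: bool = False

entries = [_Entry("vae", 300), _Entry("unet", 5000), _Entry("text_encoder", 800, is_locked=True)]
for e in sorted(entries, key=lambda x: x.total_bytes):
    if e.is_locked:
        continue  # locked models are never offloaded
    print(f"offload {e.key} ({e.total_bytes} bytes)")
# Prints: offload vae (300 bytes), then offload unet (5000 bytes)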
|
||||
|
||||
def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
|
||||
if self._logger.getEffectiveLevel() > logging.DEBUG:
|
||||
# Short circuit if the logger is not set to debug. Some of the data lookups could take a non-negligible
|
||||
# amount of time.
|
||||
def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
|
||||
:param cache_entry: The CacheRecord for the model
|
||||
:param target_device: The torch.device to move the model into
|
||||
|
||||
May raise a torch.cuda.OutOfMemoryError
|
||||
"""
|
||||
self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
source_device = cache_entry.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
log = f"{title}\n"
|
||||
# Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
if not hasattr(cache_entry.model, "to"):
|
||||
return
|
||||
|
||||
log_format = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
ram_in_use_bytes = self._get_ram_in_use()
|
||||
ram_available_bytes = self._get_ram_available()
|
||||
ram_size_bytes = ram_in_use_bytes + ram_available_bytes
|
||||
ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
|
||||
ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
|
||||
log += log_format.format(
|
||||
f"Storage Device ({self._storage_device.type})",
|
||||
ram_size_bytes / MB,
|
||||
ram_in_use_bytes / MB,
|
||||
ram_in_use_bytes_percent,
|
||||
ram_available_bytes / MB,
|
||||
ram_available_bytes_percent,
|
||||
try:
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self._storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(target_device, copy=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self._logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if self._execution_device.type != "cpu":
|
||||
vram_in_use_bytes = self._get_vram_in_use()
|
||||
vram_available_bytes = self._get_vram_available(None)
|
||||
vram_size_bytes = vram_in_use_bytes + vram_available_bytes
|
||||
vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
|
||||
vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
|
||||
log += log_format.format(
|
||||
f"Compute Device ({self._execution_device.type})",
|
||||
vram_size_bytes / MB,
|
||||
vram_in_use_bytes / MB,
|
||||
vram_in_use_bytes_percent,
|
||||
vram_available_bytes / MB,
|
||||
vram_available_bytes_percent,
|
||||
)
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
log += " {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
|
||||
log += " {:<30} {}\n".format("Total models:", len(self._cached_models))
|
||||
|
||||
if include_entry_details and len(self._cached_models) > 0:
|
||||
log += " Models:\n"
|
||||
log_format = (
|
||||
" {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n"
|
||||
)
|
||||
for cache_record in self._cached_models.values():
|
||||
total_bytes = cache_record.cached_model.total_bytes()
|
||||
cur_vram_bytes = cache_record.cached_model.cur_vram_bytes()
|
||||
cur_vram_bytes_percent = cur_vram_bytes / total_bytes if total_bytes > 0 else 0
|
||||
cur_ram_bytes = total_bytes - cur_vram_bytes
|
||||
cur_ram_bytes_percent = cur_ram_bytes / total_bytes if total_bytes > 0 else 0
|
||||
|
||||
log += log_format.format(
|
||||
f"{cache_record.key} ({cache_record.cached_model.model.__class__.__name__}):",
|
||||
total_bytes / MB,
|
||||
cur_vram_bytes / MB,
|
||||
cur_vram_bytes_percent,
|
||||
cur_ram_bytes / MB,
|
||||
cur_ram_bytes_percent,
|
||||
cache_record.is_locked,
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self._logger.debug(
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
self._logger.debug(log)
|
||||
def _print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self._get_cache_size() / GB)
|
||||
|
||||
def make_room(self, bytes_needed: int) -> None:
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
locked_in_vram_models = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
if hasattr(cache_record.model, "device"):
|
||||
if cache_record.model.device == self._storage_device:
|
||||
in_ram_models += 1
|
||||
else:
|
||||
in_vram_models += 1
|
||||
if cache_record.locked:
|
||||
locked_in_vram_models += 1
|
||||
|
||||
self._logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
|
||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
self._logger.debug(f"Making room for {bytes_needed/MB:.2f}MB of RAM.")
|
||||
self._log_cache_state(title="Before dropping models:")
|
||||
bytes_needed = size
|
||||
maximum_size = self._max_cache_size * GB # stored in GB, convert to bytes
|
||||
current_size = self._get_cache_size()
|
||||
|
||||
ram_bytes_available = self._get_ram_available()
|
||||
ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available)
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self._logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
)
|
||||
|
||||
self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
|
||||
ram_bytes_freed = 0
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack):
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self._logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
)
|
||||
|
||||
if not cache_entry.is_locked:
|
||||
ram_bytes_freed += cache_entry.cached_model.total_bytes()
|
||||
if not cache_entry.locked:
|
||||
self._logger.debug(
|
||||
f"Dropping {model_key} from RAM cache to free {(cache_entry.cached_model.total_bytes()/MB):.2f}MB."
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
self._delete_cache_entry(cache_entry)
|
||||
del cache_entry
|
||||
models_cleared += 1
|
||||
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
@@ -580,10 +435,8 @@ class ModelCache:
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self._logger.debug(f"Dropped {models_cleared} models to free {ram_bytes_freed/MB:.2f}MB of RAM.")
|
||||
self._log_cache_state(title="After dropping models:")
|
||||
self._logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
|
||||
"""Delete cache_entry from the cache if it exists. No exception is thrown if it doesn't exist."""
|
||||
self._cache_stack = [key for key in self._cache_stack if key != cache_entry.key]
|
||||
self._cached_models.pop(cache_entry.key, None)
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
|
||||
@@ -1,20 +0,0 @@
import itertools

import torch


def get_effective_device(model: torch.nn.Module) -> torch.device:
    """A utility to infer the 'effective' device of a model.

    This utility handles the case where a model is partially loaded onto the GPU, so is safer than just calling:
    `next(iter(model.parameters())).device`.

    In the worst case, this utility has to check all model parameters, so if you already know the intended model device,
    then it is better to avoid calling this function.
    """
    # If all parameters are on the CPU, return the CPU device. Otherwise, return the first non-CPU device.
    for p in itertools.chain(model.parameters(), model.buffers()):
        if p.device.type != "cpu":
            return p.device

    return torch.device("cpu")
|
||||
@@ -14,7 +14,6 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
@@ -123,7 +122,7 @@ class ModelPatcher:
|
||||
)
|
||||
|
||||
model_embeddings.weight.data[token_id] = embedding.to(
|
||||
device=TorchDevice.choose_torch_device(), dtype=text_encoder.dtype
|
||||
device=text_encoder.device, dtype=text_encoder.dtype
|
||||
)
|
||||
ti_tokens.append(token_id)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
@@ -51,7 +50,7 @@ class IA3Layer(LoRALayerBase):
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
return cast_to_device(orig_weight, weight.device) * weight
|
||||
return orig_weight * weight
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
super().to(device, dtype)
|
||||
|
||||
@@ -12,7 +12,6 @@ from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
@@ -90,7 +89,7 @@ class T2IAdapterExt(ExtensionBase):
|
||||
width=input_width,
|
||||
height=input_height,
|
||||
num_channels=model.config["in_channels"],
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
device=model.device,
|
||||
dtype=model.dtype,
|
||||
resize_mode=self._resize_mode,
|
||||
)
|
||||
|
||||
@@ -1,12 +0,0 @@
import logging
from typing import Any, MutableMapping


# Issue with type hints related to LoggerAdapter: https://github.com/python/typeshed/issues/7855
class PrefixedLoggerAdapter(logging.LoggerAdapter):  # type: ignore
    def __init__(self, logger: logging.Logger, prefix: str):
        super().__init__(logger, {})
        self.prefix = prefix

    def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, MutableMapping[str, Any]]:
        return f"[{self.prefix}] {msg}", kwargs
|
||||
@@ -604,12 +604,7 @@
|
||||
"hfForbidden": "Sie haben keinen Zugriff auf dieses HF-Modell",
|
||||
"hfTokenInvalid": "Ungültiges oder fehlendes HF-Token",
|
||||
"restoreDefaultSettings": "Klicken, um die Standardeinstellungen des Modells zu verwenden.",
|
||||
"usingDefaultSettings": "Die Standardeinstellungen des Modells werden verwendet",
|
||||
"hfTokenInvalidErrorMessage": "Ungültiges oder fehlendes HuggingFace-Token.",
|
||||
"hfTokenUnableToVerify": "HF-Token kann nicht überprüft werden",
|
||||
"hfTokenUnableToVerifyErrorMessage": "HuggingFace-Token kann nicht überprüft werden. Dies ist wahrscheinlich auf einen Netzwerkfehler zurückzuführen. Bitte versuchen Sie es später erneut.",
|
||||
"hfTokenSaved": "HF-Token gespeichert",
|
||||
"hfTokenRequired": "Sie versuchen, ein Modell herunterzuladen, für das ein gültiges HuggingFace-Token erforderlich ist."
|
||||
"usingDefaultSettings": "Die Standardeinstellungen des Modells werden verwendet"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Bilder",
|
||||
@@ -660,23 +655,12 @@
|
||||
"canvasIsCompositing": "Leinwand ist beschäftigt (wird zusammengesetzt)",
|
||||
"canvasIsFiltering": "Leinwand ist beschäftigt (wird gefiltert)",
|
||||
"canvasIsSelectingObject": "Leinwand ist beschäftigt (wird Objekt ausgewählt)",
|
||||
"noPrompts": "Keine Eingabeaufforderungen generiert",
|
||||
"noModelSelected": "Kein Modell ausgewählt"
|
||||
"noPrompts": "Keine Eingabeaufforderungen generiert"
|
||||
},
|
||||
"seed": "Seed",
|
||||
"patchmatchDownScaleSize": "Herunterskalieren",
|
||||
"seamlessXAxis": "Nahtlose X Achse",
|
||||
"seamlessYAxis": "Nahtlose Y Achse",
|
||||
"coherenceEdgeSize": "Kantengröße",
|
||||
"infillColorValue": "Füllfarbe",
|
||||
"controlNetControlMode": "Kontrollmodus",
|
||||
"cancel": {
|
||||
"cancel": "Abbrechen"
|
||||
},
|
||||
"iterations": "Iterationen",
|
||||
"guidance": "Führung",
|
||||
"coherenceMode": "Modus",
|
||||
"recallMetadata": "Metadaten abrufen"
|
||||
"seamlessYAxis": "Nahtlose Y Achse"
|
||||
},
|
||||
"settings": {
|
||||
"displayInProgress": "Zwischenbilder anzeigen",
|
||||
@@ -1143,12 +1127,6 @@
|
||||
"Nahtloses Kacheln eines Bildes entlang der horizontalen Achse."
|
||||
],
|
||||
"heading": "Nahtlose Kachelung X Achse"
|
||||
},
|
||||
"compositingCoherenceEdgeSize": {
|
||||
"paragraphs": [
|
||||
"Die Kantengröße des Kohärenzdurchlaufs."
|
||||
],
|
||||
"heading": "Kantengröße"
|
||||
}
|
||||
},
|
||||
"invocationCache": {
|
||||
|
||||
@@ -177,11 +177,7 @@
|
||||
"none": "None",
|
||||
"new": "New",
|
||||
"generating": "Generating",
|
||||
"warnings": "Warnings",
|
||||
"start": "Start",
|
||||
"count": "Count",
|
||||
"step": "Step",
|
||||
"values": "Values"
|
||||
"warnings": "Warnings"
|
||||
},
|
||||
"hrf": {
|
||||
"hrf": "High Resolution Fix",
|
||||
@@ -854,14 +850,7 @@
|
||||
"defaultVAE": "Default VAE"
|
||||
},
|
||||
"nodes": {
|
||||
"noBatchGroup": "no group",
|
||||
"generator": "Generator",
|
||||
"generatedValues": "Generated Values",
|
||||
"commitValues": "Commit Values",
|
||||
"addValue": "Add Value",
|
||||
"addNode": "Add Node",
|
||||
"lockLinearView": "Lock Linear View",
|
||||
"unlockLinearView": "Unlock Linear View",
|
||||
"addNodeToolTip": "Add Node (Shift+A, Space)",
|
||||
"addLinearView": "Add to Linear View",
|
||||
"animatedEdges": "Animated Edges",
|
||||
@@ -1035,21 +1024,11 @@
|
||||
"addingImagesTo": "Adding images to",
|
||||
"invoke": "Invoke",
|
||||
"missingFieldTemplate": "Missing field template",
|
||||
"missingInputForField": "missing input",
|
||||
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: missing input",
|
||||
"missingNodeTemplate": "Missing node template",
|
||||
"collectionEmpty": "empty collection",
|
||||
"invalidBatchConfiguration": "Invalid batch configuration",
|
||||
"batchNodeNotConnected": "Batch node not connected: {{label}}",
|
||||
"collectionTooFewItems": "too few items, minimum {{minItems}}",
|
||||
"collectionTooManyItems": "too many items, maximum {{maxItems}}",
|
||||
"collectionStringTooLong": "too long, max {{maxLength}}",
|
||||
"collectionStringTooShort": "too short, min {{minLength}}",
|
||||
"collectionNumberGTMax": "{{value}} > {{maximum}} (inc max)",
|
||||
"collectionNumberLTMin": "{{value}} < {{minimum}} (inc min)",
|
||||
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)",
|
||||
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)",
|
||||
"collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}",
|
||||
"batchNodeCollectionSizeMismatch": "Collection size mismatch on Batch {{batchGroupId}}",
|
||||
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} empty collection",
|
||||
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: too few items, minimum {{minItems}}",
|
||||
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: too many items, maximum {{maxItems}}",
|
||||
"noModelSelected": "No model selected",
|
||||
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
|
||||
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
|
||||
@@ -1206,7 +1185,6 @@
|
||||
"modelAddedSimple": "Model Added to Queue",
|
||||
"modelImportCanceled": "Model Import Canceled",
|
||||
"outOfMemoryError": "Out of Memory Error",
|
||||
"outOfMemoryErrorDescLocal": "Follow our <LinkComponent>Low VRAM guide</LinkComponent> to reduce OOMs.",
|
||||
"outOfMemoryErrorDesc": "Your current generation settings exceed system capacity. Please adjust your settings and try again.",
|
||||
"parameters": "Parameters",
|
||||
"parameterSet": "Parameter Recalled",
|
||||
@@ -1953,24 +1931,6 @@
|
||||
"description": "Generates an edge map from the selected layer using the PiDiNet edge detection model.",
|
||||
"scribble": "Scribble",
|
||||
"quantize_edges": "Quantize Edges"
|
||||
},
|
||||
"img_blur": {
|
||||
"label": "Blur Image",
|
||||
"description": "Blurs the selected layer.",
|
||||
"blur_type": "Blur Type",
|
||||
"blur_radius": "Radius",
|
||||
"gaussian_type": "Gaussian",
|
||||
"box_type": "Box"
|
||||
},
|
||||
"img_noise": {
|
||||
"label": "Noise Image",
|
||||
"description": "Adds noise to the selected layer.",
|
||||
"noise_type": "Noise Type",
|
||||
"noise_amount": "Amount",
|
||||
"gaussian_type": "Gaussian",
|
||||
"salt_and_pepper_type": "Salt and Pepper",
|
||||
"noise_color": "Colored Noise",
|
||||
"size": "Noise Size"
|
||||
}
|
||||
},
|
||||
"transform": {
|
||||
@@ -2173,12 +2133,15 @@
|
||||
"toGetStartedLocal": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
|
||||
"toGetStarted": "To get started, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
|
||||
"gettingStartedSeries": "Want more guidance? Check out our <LinkComponent>Getting Started Series</LinkComponent> for tips on unlocking the full potential of the Invoke Studio.",
|
||||
"lowVRAMMode": "For best performance, follow our <LinkComponent>Low VRAM guide</LinkComponent>.",
|
||||
"noModelsInstalled": "It looks like you don't have any models installed! You can <DownloadStarterModelsButton>download a starter model bundle</DownloadStarterModelsButton> or <ImportModelsButton>import models</ImportModelsButton>."
|
||||
"downloadStarterModels": "Download Starter Models",
|
||||
"importModels": "Import Models",
|
||||
"noModelsInstalled": "It looks like you don't have any models installed"
|
||||
},
|
||||
"whatsNew": {
|
||||
"whatsNewInInvoke": "What's New in Invoke",
|
||||
"items": ["Low-VRAM mode", "Dynamic memory management", "Faster model loading times", "Fewer memory errors"],
|
||||
"items": [
|
||||
"<StrongComponent>Flux Control Layers</StrongComponent>: New control models for edge detection and depth mapping are now supported for Flux dev models."
|
||||
],
|
||||
"readReleaseNotes": "Read Release Notes",
|
||||
"watchRecentReleaseVideos": "Watch Recent Release Videos",
|
||||
"watchUiUpdatesOverview": "Watch UI Updates Overview"
|
||||
|
||||
@@ -610,13 +610,9 @@
|
||||
"hfTokenSaved": "Gettone HF salvato",
|
||||
"hfForbidden": "Non hai accesso a questo modello HF",
|
||||
"hfTokenLabel": "Gettone HuggingFace (richiesto per alcuni modelli)",
|
||||
"hfForbiddenErrorMessage": "Consigliamo di visitare la pagina del repository. Il proprietario potrebbe richiedere l'accettazione dei termini per poter effettuare il download.",
|
||||
"hfForbiddenErrorMessage": "Consigliamo di visitare la pagina del repository su HuggingFace.com. Il proprietario potrebbe richiedere l'accettazione dei termini per poter effettuare il download.",
|
||||
"hfTokenInvalidErrorMessage2": "Aggiornalo in ",
|
||||
"controlLora": "Controllo LoRA",
|
||||
"urlUnauthorizedErrorMessage2": "Scopri come qui.",
|
||||
"urlForbidden": "Non hai accesso a questo modello",
|
||||
"urlForbiddenErrorMessage": "Potrebbe essere necessario richiedere l'autorizzazione al sito che distribuisce il modello.",
|
||||
"urlUnauthorizedErrorMessage": "Potrebbe essere necessario configurare un gettone API per accedere a questo modello."
|
||||
"controlLora": "Controllo LoRA"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
{
|
||||
"common": {
|
||||
"hotkeysLabel": "Skróty klawiszowe",
|
||||
"languagePickerLabel": "Język",
|
||||
"languagePickerLabel": "Wybór języka",
|
||||
"reportBugLabel": "Zgłoś błąd",
|
||||
"settingsLabel": "Ustawienia",
|
||||
"img2img": "Obraz na obraz",
|
||||
"nodes": "Węzły",
|
||||
"upload": "Prześlij",
|
||||
"load": "Załaduj",
|
||||
"statusDisconnected": "Odłączono",
|
||||
"statusDisconnected": "Odłączono od serwera",
|
||||
"githubLabel": "GitHub",
|
||||
"discordLabel": "Discord",
|
||||
"clipboard": "Schowek",
|
||||
@@ -27,87 +27,13 @@
|
||||
"back": "Do tyłu",
|
||||
"auto": "Automatyczny",
|
||||
"beta": "Beta",
|
||||
"close": "Wyjdź",
|
||||
"checkpoint": "Punkt kontrolny",
|
||||
"controlNet": "ControlNet",
|
||||
"details": "Detale",
|
||||
"direction": "Kierunek",
|
||||
"ipAdapter": "Adapter IP",
|
||||
"dontAskMeAgain": "Nie pytaj ponownie",
|
||||
"modelManager": "Menedżer modeli",
|
||||
"blue": "Niebieski",
|
||||
"orderBy": "Sortuj według",
|
||||
"openInNewTab": "Otwórz w nowym oknie",
|
||||
"somethingWentWrong": "Coś poszło nie tak",
|
||||
"green": "Zielony",
|
||||
"red": "Czerwony",
|
||||
"imageFailedToLoad": "Nie można załadować obrazu",
|
||||
"saveAs": "Zapisz jako",
|
||||
"outputs": "Wyjścia",
|
||||
"data": "Dane",
|
||||
"localSystem": "System Lokalny",
|
||||
"t2iAdapter": "Adapter T2I",
|
||||
"selected": "Zaznaczone",
|
||||
"warnings": "Ostrzeżenia",
|
||||
"save": "Zapisz",
|
||||
"created": "Stworzono",
|
||||
"alpha": "Alfa",
|
||||
"error": "Bład",
|
||||
"editor": "Edytor",
|
||||
"loading": "Ładuję",
|
||||
"edit": "Edytuj",
|
||||
"enabled": "Aktywny",
|
||||
"communityLabel": "Społeczeństwo",
|
||||
"linear": "Liniowy",
|
||||
"installed": "Zainstalowany",
|
||||
"dontShowMeThese": "Nie pokazuj mi tego",
|
||||
"openInViewer": "Otwórz podgląd",
|
||||
"safetensors": "Bezpieczniki",
|
||||
"ok": "Ok",
|
||||
"goTo": "Idź do",
|
||||
"loadingImage": "wczytywanie zdjęcia",
|
||||
"input": "Wejście",
|
||||
"view": "Podgląd",
|
||||
"learnMore": "Dowiedz się więcej",
|
||||
"notInstalled": "Nie $t(common.installed)",
|
||||
"loadingModel": "Wczytywanie modelu",
|
||||
"postprocessing": "Przetwarzanie końcowe",
|
||||
"random": "Losowo",
|
||||
"disabled": "Wyłączony",
|
||||
"generating": "Generowanie",
|
||||
"simple": "Prosty",
|
||||
"folder": "Katalog",
|
||||
"format": "Format",
|
||||
"updated": "Zaktualizowano",
|
||||
"unknown": "nieznany",
|
||||
"delete": "Usuń",
|
||||
"template": "Szablon",
|
||||
"txt2img": "Tekst na obraz",
|
||||
"prevPage": "Poprzednia strona",
|
||||
"file": "Plik",
|
||||
"toResolve": "Do rozwiązania",
|
||||
"nextPage": "Następna strona",
|
||||
"unknownError": "Nieznany błąd",
|
||||
"placeholderSelectAModel": "Wybierz model",
|
||||
"new": "Nowy",
|
||||
"none": "Żadne",
|
||||
"reset": "Reset",
|
||||
"on": "Włączony",
|
||||
"aboutHeading": "Posiadaj swoją kreatywną moc"
|
||||
"close": "Wyjdź"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Rozmiar obrazów",
|
||||
"gallerySettings": "Ustawienia galerii",
|
||||
"autoSwitchNewImages": "Przełączaj na nowe obrazy",
|
||||
"noImagesInGallery": "Brak obrazów w galerii",
|
||||
"gallery": "Galeria",
|
||||
"alwaysShowImageSizeBadge": "Zawsze pokazuj odznakę wielkości obrazu",
|
||||
"assetsTab": "Pliki, które wrzuciłeś do użytku w twoich projektach.",
|
||||
"currentlyInUse": "Ten obraz jest obecnie w użyciu przez następujące funkcje:",
|
||||
"boardsSettings": "Ustawienia tablic",
|
||||
"assets": "Aktywy",
|
||||
"autoAssignBoardOnClick": "Automatycznie przypisz tablicę po kliknięciu",
|
||||
"copy": "Kopiuj"
|
||||
"noImagesInGallery": "Brak obrazów w galerii"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "L. obrazów",
|
||||
@@ -157,14 +83,7 @@
|
||||
"previousImage": "Poprzedni obraz",
|
||||
"nextImage": "Następny obraz",
|
||||
"menu": "Menu",
|
||||
"mode": "Tryb",
|
||||
"resetUI": "$t(accessibility.reset) UI",
|
||||
"uploadImages": "Wgrywaj obrazy",
|
||||
"about": "Informacje",
|
||||
"toggleRightPanel": "Przełącz prawy panel (G)",
|
||||
"toggleLeftPanel": "Przełącz lewy panel (G)",
|
||||
"createIssue": "Stwórz problem",
|
||||
"submitSupportTicket": "Wyślij bilet pomocy"
|
||||
"mode": "Tryb"
|
||||
},
|
||||
"boards": {
|
||||
"cancel": "Anuluj",
|
||||
@@ -179,34 +98,7 @@
|
||||
"downloadBoard": "Pobierz tablice",
|
||||
"loading": "Ładowanie...",
|
||||
"move": "Przenieś",
|
||||
"noMatching": "Brak pasujących tablic",
|
||||
"addBoard": "Dodaj tablicę",
|
||||
"autoAddBoard": "Automatycznie dodaj tablicę",
|
||||
"searchBoard": "Szukaj tablic.",
|
||||
"unarchiveBoard": "Odarchiwizuj tablicę",
|
||||
"selectedForAutoAdd": "Wybrany do automatycznego dodania",
|
||||
"deleteBoard": "Usuń tablicę",
|
||||
"clearSearch": "Usuń historię",
|
||||
"hideBoards": "Ukryj tablice",
|
||||
"viewBoards": "Zobacz tablice",
|
||||
"addSharedBoard": "Dodaj udostępnioną tablicę",
|
||||
"boards": "Tablice",
|
||||
"addPrivateBoard": "Dodaj prywatną tablicę",
|
||||
"movingImagesToBoard_one": "Przenoszenie {{count}} zdjęcia do tablicy:",
|
||||
"movingImagesToBoard_few": "Przenoszenie {{count}} zdjęć do tablicy:",
|
||||
"movingImagesToBoard_many": "Przenoszenie {{count}} zdjęć do tablicy:",
|
||||
"shared": "Udostępnione tablice",
|
||||
"topMessage": "Ta tablica zawiera obrazy wykorzystywane w następujących funkcjach:",
|
||||
"deletedPrivateBoardsCannotbeRestored": "Usunięte tablice nie mogą być odzyskane. Wybierając \"Usuń tylko tablicę\" spowoduje że obrazy zostaną przeniesione do prywatnego nieskategoryzowanego stanu autora obrazu.",
|
||||
"changeBoard": "Zmień tablicę",
|
||||
"bottomMessage": "Usuwając tę tablicę oraz jej obrazów zresetują wszystkie funkcje które obecnie ich używają.",
|
||||
"deleteBoardAndImages": "Usuń tablicę i zdjęcia",
|
||||
"deleteBoardOnly": "Usuń tylko tablicę",
|
||||
"deletedBoardsCannotbeRestored": "Usunięte tablice nie mogą być odzyskane. Wybierając \"Usuń tylko tablicę\" spowoduje że obrazy zostaną przeniesione do nieskategoryzowanego stanu.",
|
||||
"archiveBoard": "Zarchiwizuj tablicę",
|
||||
"archived": "Zarchiwizowano",
|
||||
"myBoard": "Moja tablica",
|
||||
"menuItemAutoAdd": "Automatycznie dodaj do tej tablicy"
|
||||
"noMatching": "Brak pasujących tablic"
|
||||
},
|
||||
"accordions": {
|
||||
"compositing": {
|
||||
@@ -227,103 +119,5 @@
|
||||
"control": {
|
||||
"title": "Kontrola"
|
||||
}
|
||||
},
|
||||
"hrf": {
|
||||
"metadata": {
|
||||
"enabled": "Włączono poprawkę wysokiej rozdzielczości",
|
||||
"strength": "Moc poprawki wysokiej rozdzielczości",
|
||||
"method": "Metoda High Resolution Fix"
|
||||
},
|
||||
"hrf": "Poprawka \"Wysoka rozdzielczość\"",
|
||||
"enableHrf": "Włącz poprawkę wysokiej rozdzielczości"
|
||||
},
|
||||
"queue": {
|
||||
"cancelTooltip": "Anuluj aktualną pozycję",
|
||||
"resumeFailed": "Błąd z kontynuowaniem procesora",
|
||||
"current": "Obecne",
|
||||
"cancelBatchFailed": "Problem z anulacją masy",
|
||||
"queueFront": "Dodaj do przodu kolejki",
|
||||
"cancelBatch": "Anuluj serię",
|
||||
"cancelFailed": "Problem z anulowaniem pozycji",
|
||||
"pruneTooltip": "Wyczyść {{item_count}} skończonych pozycji",
|
||||
"pruneSucceeded": "Wyczyszczono {{item_count}} zakończonych pozycji z kolejki",
|
||||
"cancelBatchSucceeded": "Partia anulowana",
|
||||
"clear": "Wyczyść",
|
||||
"clearTooltip": "Anuluj i usuń wszystkie pozycje",
|
||||
"clearSucceeded": "Kolejka wyczyszczona",
|
||||
"cancelItem": "Anuluj pozycję",
|
||||
"clearQueueAlertDialog2": "Czy na pewno chcesz wyczyścić kolejkę?",
|
||||
"pauseFailed": "Problem z zapauzowaniem processora",
|
||||
"clearFailed": "Problem z czyszczeniem kolejki",
|
||||
"queueBack": "Dodaj do kolejki",
|
||||
"queueEmpty": "Kolejka pusta",
|
||||
"enqueueing": "Kolejkowanie partii",
|
||||
"resumeTooltip": "Kontynuuj processor",
|
||||
"resumeSucceeded": "Processor kontynuowany",
|
||||
"pause": "Zapauzuj",
|
||||
"pauseTooltip": "Zapauzuj processor",
|
||||
"queue": "Kolejka",
|
||||
"resume": "Kontynuuj",
|
||||
"cancel": "Anuluj",
|
||||
"cancelSucceeded": "Pozycja anulowana",
|
||||
"prune": "Wyczyść",
|
||||
"pauseSucceeded": "Processor zapauzowany",
|
||||
"clearQueueAlertDialog": "Czyszczenie kolejki od razu anuluje wszystkie przetwarzane elementy and całkowicie czyści kolejkę. Oczekujące filtry zostaną anulowane.",
|
||||
"pruneFailed": "Problem z wyczyszczeniem kolejki",
|
||||
"batchQueued": "Masa w kolejce",
|
||||
"openQueue": "Otwórz kolejkę",
|
||||
"iterations_one": "Iteracja",
|
||||
"iterations_few": "Iteracje",
|
||||
"iterations_many": "Iteracje",
|
||||
"graphQueued": "Wykres w kolejce",
|
||||
"canvas": "Płótno",
|
||||
"generation": "Generacja",
|
||||
"status": "Status",
|
||||
"total": "Suma",
|
||||
"time": "Czas",
|
||||
"front": "Przód",
|
||||
"back": "tył",
|
||||
"batchFailedToQueue": "Nie można zkolejkować masy",
|
||||
"completedIn": "Ukończony w całości",
|
||||
"other": "Inne",
|
||||
"origin": "Pochodzenie",
|
||||
"destination": "Miejsce docelowe",
|
||||
"notReady": "Nie można zkolejkować",
|
||||
"canceled": "Anulowano",
|
||||
"in_progress": "W trakcie",
|
||||
"gallery": "Galeria",
|
||||
"session": "Sesja",
|
||||
"pending": "W toku",
|
||||
"completed": "Zakończono",
|
||||
"item": "Pozycja",
|
||||
"failed": "Niepowodzenie",
|
||||
"batchFieldValues": "Masowe Wartości pól",
|
||||
"graphFailedToQueue": "NIe udało się dodać tabeli do kolejki",
|
||||
"workflows": "Przepływy pracy",
|
||||
"next": "Następny",
|
||||
"batchQueuedDesc_one": "Dodano {{count}} sesję do {{direction}} kolejki",
|
||||
"batchQueuedDesc_few": "Dodano {{count}} sesje do {{direction}} kolejki",
|
||||
"batchQueuedDesc_many": "Dodano {{count}} sesje do {{direction}} kolejki",
|
||||
"batch": "Masa",
|
||||
"upscaling": "Skalowanie w górę",
|
||||
"generations_one": "Generacja",
|
||||
"generations_few": "Generacje",
|
||||
"generations_many": "Generacje",
|
||||
"prompts_one": "Monit",
|
||||
"prompts_few": "Monity",
|
||||
"prompts_many": "Monity",
|
||||
"batchSize": "Rozmiar masy"
|
||||
},
|
||||
"prompt": {
|
||||
"compatibleEmbeddings": "Kompatybilne osadzenia",
|
||||
"noMatchingTriggers": "Nie dopasowywanie spustów"
|
||||
},
|
||||
"invocationCache": {
|
||||
"hits": "Uderzenia cache",
|
||||
"enable": "Włącz",
|
||||
"clear": "Wyczyść",
|
||||
"disable": "Wyłącz",
|
||||
"maxCacheSize": "Maksymalny rozmiar cache",
|
||||
"cacheSize": "Rozmiar Cache"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@
|
||||
"resume": "Tiếp Tục",
|
||||
"enqueueing": "Xếp Vào Hàng Hàng Loạt",
|
||||
"prompts_other": "Lệnh",
|
||||
"iterations_other": "Vòng Lặp",
|
||||
"iterations_other": "Lặp Lại",
|
||||
"total": "Tổng",
|
||||
"pruneFailed": "Có Vấn Đề Khi Cắt Bớt Mục Khỏi Hàng",
|
||||
"clearSucceeded": "Hàng Đã Được Dọn Sạch",
|
||||
@@ -271,7 +271,7 @@
|
||||
"queueFront": "Thêm Vào Đầu Hàng",
|
||||
"resumeTooltip": "Tiếp Tục Bộ Xử Lý",
|
||||
"clearFailed": "Có Vấn Đề Khi Dọn Dẹp Hàng",
|
||||
"generations_other": "Ảnh Tạo Sinh",
|
||||
"generations_other": "Máy Tạo Sinh",
|
||||
"cancelBatch": "Huỷ Bỏ Lô",
|
||||
"status": "Trạng Thái",
|
||||
"pending": "Đang Chờ",
|
||||
@@ -648,7 +648,7 @@
|
||||
"defaultSettingsSaved": "Đã Lưu Thiết Lập Mặc Định",
|
||||
"description": "Dòng Mô Tả",
|
||||
"imageEncoderModelId": "ID Model Image Encoder",
|
||||
"hfForbiddenErrorMessage": "Chúng tôi gợi ý vào các repository. Chủ sở hữu có thể yêu cầu chấp nhận điều khoản để tải xuống.",
|
||||
"hfForbiddenErrorMessage": "Chúng tôi gợi ý vào trang repository trên HuggingFace.com. Chủ sở hữu có thể yêu cầu chấp nhận điều khoản để tải xuống.",
|
||||
"hfTokenSaved": "Đã Lưu HF Token",
|
||||
"learnMoreAboutSupportedModels": "Tìm hiểu thêm về những model được hỗ trợ",
|
||||
"availableModels": "Model Có Sẵn",
|
||||
@@ -731,7 +731,7 @@
|
||||
"repo_id": "ID Repository",
|
||||
"scanFolder": "Quét Thư Mục",
|
||||
"scanFolderHelper": "Thư mục sẽ được quét để tìm model. Có thể sẽ mất nhiều thời gian với những thư mục lớn.",
|
||||
"scanResults": "Kết Quả",
|
||||
"scanResults": "Kết Quả Quét",
|
||||
"t5Encoder": "T5 Encoder",
|
||||
"mainModelTriggerPhrases": "Từ Ngữ Kích Hoạt Cho Model Chính",
|
||||
"textualInversions": "Bộ Đảo Ngược Văn Bản",
|
||||
@@ -739,12 +739,7 @@
|
||||
"width": "Chiều Rộng",
|
||||
"starterModelsInModelManager": "Model khởi đầu có thể tìm thấy ở Trình Quản Lý Model",
|
||||
"clipLEmbed": "CLIP-L Embed",
|
||||
"clipGEmbed": "CLIP-G Embed",
|
||||
"controlLora": "LoRA Điều Khiển Được",
|
||||
"urlUnauthorizedErrorMessage2": "Tìm hiểu thêm.",
|
||||
"urlForbidden": "Bạn không có quyền truy cập vào model này",
|
||||
"urlForbiddenErrorMessage": "Bạn có thể cần yêu cầu quyền truy cập từ trang web đang cung cấp model.",
|
||||
"urlUnauthorizedErrorMessage": "Bạn có thể cần thiếp lập một token API để dùng được model này."
|
||||
"clipGEmbed": "CLIP-G Embed"
|
||||
},
|
||||
"metadata": {
|
||||
"guidance": "Hướng Dẫn",
|
||||
@@ -1438,8 +1433,7 @@
|
||||
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} tài nguyên trống",
|
||||
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: quá ít mục, tối thiểu {{minItems}}",
|
||||
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: quá nhiều mục, tối đa {{maxItems}}",
|
||||
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)",
|
||||
"fluxModelMultipleControlLoRAs": "Chỉ có thể dùng 1 LoRA Điều Khiển Được"
|
||||
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)"
|
||||
},
|
||||
"cfgScale": "Thang CFG",
|
||||
"useSeed": "Dùng Hạt Giống",
|
||||
|
||||
@@ -2,19 +2,10 @@ import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
|
||||
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { groupBy } from 'lodash-es';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { Batch, BatchConfig } from 'services/api/types';
|
||||
|
||||
@@ -42,140 +33,28 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
|
||||
const data: Batch['data'] = [];
|
||||
|
||||
const batchNodes = nodes.nodes.filter(isInvocationNode).filter(isBatchNode);
|
||||
|
||||
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
|
||||
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
|
||||
|
||||
const addProductBatchDataCollectionItem = (
|
||||
edges: InvocationNodeEdge[],
|
||||
items?: ImageField[] | string[] | number[]
|
||||
) => {
|
||||
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
for (const edge of edges) {
|
||||
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
|
||||
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
|
||||
for (const node of imageBatchNodes) {
|
||||
const images = node.data.inputs['images'];
|
||||
if (!isImageFieldCollectionInputInstance(images)) {
|
||||
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
|
||||
break;
|
||||
}
|
||||
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
|
||||
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
|
||||
for (const edge of edgesFromImageBatch) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
productBatchDataCollectionItems.push({
|
||||
batchDataCollectionItem.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items,
|
||||
items: images.value,
|
||||
});
|
||||
}
|
||||
if (productBatchDataCollectionItems.length > 0) {
|
||||
data.push(productBatchDataCollectionItems);
|
||||
}
|
||||
};
|
||||
|
||||
// Then, we will create a batch data collection item for each group
|
||||
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
|
||||
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
const addZippedBatchDataCollectionItem = (
|
||||
edges: InvocationNodeEdge[],
|
||||
items?: ImageField[] | string[] | number[]
|
||||
) => {
|
||||
for (const edge of edges) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
zippedBatchDataCollectionItems.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Grab image batch nodes for special handling
|
||||
const imageBatchNodes = batchNodes.filter((node) => node.data.type === 'image_batch');
|
||||
|
||||
for (const node of imageBatchNodes) {
|
||||
// Satisfy TS
|
||||
const images = node.data.inputs['images'];
|
||||
if (!isImageFieldCollectionInputInstance(images)) {
|
||||
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
|
||||
break;
|
||||
}
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromImageBatch, images.value);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromImageBatch, images.value);
|
||||
}
|
||||
}
|
||||
|
||||
// Grab string batch nodes for special handling
|
||||
const stringBatchNodes = batchNodes.filter((node) => node.data.type === 'string_batch');
|
||||
for (const node of stringBatchNodes) {
|
||||
// Satisfy TS
|
||||
const strings = node.data.inputs['strings'];
|
||||
if (!isStringFieldCollectionInputInstance(strings)) {
|
||||
log.warn({ nodeId: node.id }, 'String batch strings field is not a string collection');
|
||||
break;
|
||||
}
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, strings.value);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, strings.value);
|
||||
}
|
||||
}
|
||||
|
||||
// Grab integer batch nodes for special handling
|
||||
const integerBatchNodes = batchNodes.filter((node) => node.data.type === 'integer_batch');
|
||||
for (const node of integerBatchNodes) {
|
||||
// Satisfy TS
|
||||
const integers = node.data.inputs['integers'];
|
||||
if (!isIntegerFieldCollectionInputInstance(integers)) {
|
||||
log.warn({ nodeId: node.id }, 'Integer batch integers field is not an integer collection');
|
||||
break;
|
||||
}
|
||||
if (!integers.value) {
|
||||
log.warn({ nodeId: node.id }, 'Integer batch integers field is empty');
|
||||
break;
|
||||
}
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
const resolvedValue = resolveNumberFieldCollectionValue(integers);
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
}
|
||||
}
|
||||
|
||||
// Grab float batch nodes for special handling
|
||||
const floatBatchNodes = batchNodes.filter((node) => node.data.type === 'float_batch');
|
||||
for (const node of floatBatchNodes) {
|
||||
// Satisfy TS
|
||||
const floats = node.data.inputs['floats'];
|
||||
if (!isFloatFieldCollectionInputInstance(floats)) {
|
||||
log.warn({ nodeId: node.id }, 'Float batch floats field is not a float collection');
|
||||
break;
|
||||
}
|
||||
if (!floats.value) {
|
||||
log.warn({ nodeId: node.id }, 'Float batch floats field is empty');
|
||||
break;
|
||||
}
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
const resolvedValue = resolveNumberFieldCollectionValue(floats);
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, if this batch data collection item has any items, add it to the data array
|
||||
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
|
||||
data.push(zippedBatchDataCollectionItems);
|
||||
if (batchDataCollectionItem.length > 0) {
|
||||
data.push(batchDataCollectionItem);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
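The hunk above reworks how batch nodes feed `Batch['data']`: nodes sharing a `batch_group_id` are accumulated into one zipped collection per group, while ungrouped (`'None'`) nodes each push their own collection, which the queue then combines as a product. A rough sketch of that grouping, using simplified hypothetical types rather than the app's real `Batch` and node shapes:

```ts
// Simplified, illustrative types - not the app's actual Batch or node shapes.
type BatchDatum = { node_path: string; field_name: string; items: unknown[] };
type BatchNode = {
  id: string;
  groupId: string; // 'None' means the node is not zipped with any other node
  items: unknown[];
  targets: { node: string; field: string }[];
};

const buildBatchData = (nodes: BatchNode[]): BatchDatum[][] => {
  const data: BatchDatum[][] = [];
  const zippedByGroup = new Map<string, BatchDatum[]>();

  for (const node of nodes) {
    const collectionItems = node.targets.map((t) => ({
      node_path: t.node,
      field_name: t.field,
      items: node.items,
    }));
    if (node.groupId !== 'None') {
      // Same group -> shared collection, so grouped nodes iterate in lockstep (zip).
      const collection = zippedByGroup.get(node.groupId) ?? [];
      collection.push(...collectionItems);
      zippedByGroup.set(node.groupId, collection);
    } else if (collectionItems.length > 0) {
      // Ungrouped -> its own collection, combined with the others as a product.
      data.push(collectionItems);
    }
  }

  for (const collection of zippedByGroup.values()) {
    if (collection.length > 0) {
      data.push(collection);
    }
  }

  return data;
};
```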
@@ -166,10 +166,8 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: false,
immutableCheck: false,
// serializableCheck: import.meta.env.MODE === 'development',
// immutableCheck: import.meta.env.MODE === 'development',
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
})
.concat(api.middleware)
.concat(dynamicMiddlewares)

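The store hunk above switches Redux Toolkit's `serializableCheck` and `immutableCheck` from being hard-disabled to running only in development builds. A minimal sketch of the pattern, assuming a Vite-style `import.meta.env.MODE`:

```ts
import { configureStore } from '@reduxjs/toolkit';

// Trivial reducer so the sketch stands alone.
const rootReducer = (state: { ready: boolean } = { ready: true }) => state;

// These checks are costly, so keep them out of production bundles.
// `import.meta.env.MODE` assumes a Vite build; use process.env.NODE_ENV elsewhere.
const isDev = import.meta.env.MODE === 'development';

export const store = configureStore({
  reducer: rootReducer,
  middleware: (getDefaultMiddleware) =>
    getDefaultMiddleware({
      serializableCheck: isDev,
      immutableCheck: isDev,
    }),
});
```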
@@ -57,7 +57,6 @@ export const CanvasMainPanelContent = memo(() => {
gap={2}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
<CanvasManagerProviderGate>
<CanvasToolbar />
@@ -71,7 +70,6 @@ export const CanvasMainPanelContent = memo(() => {
h="full"
bg={dynamicGrid ? 'base.850' : 'base.900'}
borderRadius="base"
overflow="hidden"
>
<InvokeCanvasComponent />
<CanvasManagerProviderGate>

@@ -1,72 +0,0 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { BlurFilterConfig, BlurTypes } from 'features/controlLayers/store/filters';
|
||||
import { IMAGE_FILTERS, isBlurTypes } from 'features/controlLayers/store/filters';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { FilterComponentProps } from './types';
|
||||
|
||||
type Props = FilterComponentProps<BlurFilterConfig>;
|
||||
const DEFAULTS = IMAGE_FILTERS.img_blur.buildDefaults();
|
||||
|
||||
export const FilterBlur = memo(({ onChange, config }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const handleBlurTypeChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isBlurTypes(v?.value)) {
|
||||
return;
|
||||
}
|
||||
onChange({ ...config, blur_type: v.value });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const handleRadiusChange = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...config, radius: v });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const options: { label: string; value: BlurTypes }[] = useMemo(
|
||||
() => [
|
||||
{ label: t('controlLayers.filter.img_blur.gaussian_type'), value: 'gaussian' },
|
||||
{ label: t('controlLayers.filter.img_blur.box_type'), value: 'box' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.filter((o) => o.value === config.blur_type)[0], [options, config.blur_type]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_type')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={handleBlurTypeChange} isSearchable={false} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_radius')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={config.radius}
|
||||
defaultValue={DEFAULTS.radius}
|
||||
onChange={handleRadiusChange}
|
||||
min={1}
|
||||
max={64}
|
||||
step={0.1}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={config.radius}
|
||||
defaultValue={DEFAULTS.radius}
|
||||
onChange={handleRadiusChange}
|
||||
min={1}
|
||||
max={4096}
|
||||
step={0.1}
|
||||
/>
|
||||
</FormControl>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
FilterBlur.displayName = 'FilterBlur';
|
||||
@@ -1,111 +0,0 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import type { NoiseFilterConfig, NoiseTypes } from 'features/controlLayers/store/filters';
|
||||
import { IMAGE_FILTERS, isNoiseTypes } from 'features/controlLayers/store/filters';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { FilterComponentProps } from './types';
|
||||
|
||||
type Props = FilterComponentProps<NoiseFilterConfig>;
|
||||
const DEFAULTS = IMAGE_FILTERS.img_noise.buildDefaults();
|
||||
|
||||
export const FilterNoise = memo(({ onChange, config }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const handleNoiseTypeChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isNoiseTypes(v?.value)) {
|
||||
return;
|
||||
}
|
||||
onChange({ ...config, noise_type: v.value });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const handleAmountChange = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...config, amount: v });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const handleColorChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChange({ ...config, noise_color: e.target.checked });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const handleSizeChange = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...config, size: v });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const options: { label: string; value: NoiseTypes }[] = useMemo(
|
||||
() => [
|
||||
{ label: t('controlLayers.filter.img_noise.gaussian_type'), value: 'gaussian' },
|
||||
{ label: t('controlLayers.filter.img_noise.salt_and_pepper_type'), value: 'salt_and_pepper' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.filter((o) => o.value === config.noise_type)[0], [options, config.noise_type]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_type')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={handleNoiseTypeChange} isSearchable={false} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_amount')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={config.amount}
|
||||
defaultValue={DEFAULTS.amount}
|
||||
onChange={handleAmountChange}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.01}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={config.amount}
|
||||
defaultValue={DEFAULTS.amount}
|
||||
onChange={handleAmountChange}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.01}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_noise.size')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={config.size}
|
||||
defaultValue={DEFAULTS.size}
|
||||
onChange={handleSizeChange}
|
||||
min={1}
|
||||
max={16}
|
||||
step={1}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={config.size}
|
||||
defaultValue={DEFAULTS.size}
|
||||
onChange={handleSizeChange}
|
||||
min={1}
|
||||
max={256}
|
||||
step={1}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl w="max-content">
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_color')}</FormLabel>
|
||||
<Switch defaultChecked={DEFAULTS.noise_color} isChecked={config.noise_color} onChange={handleColorChange} />
|
||||
</FormControl>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
FilterNoise.displayName = 'FilterNoise';
|
||||
@@ -1,5 +1,4 @@
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { FilterBlur } from 'features/controlLayers/components/Filters/FilterBlur';
|
||||
import { FilterCannyEdgeDetection } from 'features/controlLayers/components/Filters/FilterCannyEdgeDetection';
|
||||
import { FilterColorMap } from 'features/controlLayers/components/Filters/FilterColorMap';
|
||||
import { FilterContentShuffle } from 'features/controlLayers/components/Filters/FilterContentShuffle';
|
||||
@@ -9,7 +8,6 @@ import { FilterHEDEdgeDetection } from 'features/controlLayers/components/Filter
|
||||
import { FilterLineartEdgeDetection } from 'features/controlLayers/components/Filters/FilterLineartEdgeDetection';
|
||||
import { FilterMediaPipeFaceDetection } from 'features/controlLayers/components/Filters/FilterMediaPipeFaceDetection';
|
||||
import { FilterMLSDDetection } from 'features/controlLayers/components/Filters/FilterMLSDDetection';
|
||||
import { FilterNoise } from 'features/controlLayers/components/Filters/FilterNoise';
|
||||
import { FilterPiDiNetEdgeDetection } from 'features/controlLayers/components/Filters/FilterPiDiNetEdgeDetection';
|
||||
import { FilterSpandrel } from 'features/controlLayers/components/Filters/FilterSpandrel';
|
||||
import type { FilterConfig } from 'features/controlLayers/store/filters';
|
||||
@@ -21,10 +19,6 @@ type Props = { filterConfig: FilterConfig; onChange: (filterConfig: FilterConfig
|
||||
export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (filterConfig.type === 'img_blur') {
|
||||
return <FilterBlur config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
if (filterConfig.type === 'canny_edge_detection') {
|
||||
return <FilterCannyEdgeDetection config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
@@ -65,10 +59,6 @@ export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
|
||||
return <FilterPiDiNetEdgeDetection config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
if (filterConfig.type === 'img_noise') {
|
||||
return <FilterNoise config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
if (filterConfig.type === 'spandrel_filter') {
|
||||
return <FilterSpandrel config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
@@ -297,9 +297,10 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
const imageState = imageDTOToImageObject(filterResult.value);
this.$imageState.set(imageState);

// Stash the existing image module - we will destroy it after the new image is rendered to prevent a flash
// of an empty layer
const oldImageModule = this.imageModule;
// Destroy any existing masked image and create a new one
if (this.imageModule) {
this.imageModule.destroy();
}

this.imageModule = new CanvasObjectImage(imageState, this);

@@ -308,16 +309,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {

this.konva.group.add(this.imageModule.konva.group);

// The filtered image has some transparency, so we need to hide the objects of the parent entity to prevent the
// two images from blending. We will show the objects again in the teardown method, which is always called after
// the filter finishes (applied or canceled).
this.parent.renderer.hideObjects();

if (oldImageModule) {
// Destroy the old image module now that the new one is rendered
oldImageModule.destroy();
}

// The processing is complete, so we can set the last processed hash and isProcessing to false
this.$lastProcessedHash.set(hash);

@@ -433,8 +424,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {

teardown = () => {
this.unsubscribe();
// Re-enable the objects of the parent entity
this.parent.renderer.showObjects();
this.konva.group.remove();
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.

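The filterer hunk above stashes the previous image module, renders the replacement, and only then destroys the old one, so the layer never shows an empty frame. A rough sketch of that ordering, with hypothetical `render`/`destroy` placeholders rather than the real module API:

```ts
// Hypothetical minimal interface - the real CanvasObjectImage has much more to it.
interface Disposable {
  render(): Promise<void>;
  destroy(): void;
}

// Keep the old object alive until the new one is on screen to avoid a flash of empty content.
async function swapWithoutFlash(current: Disposable | null, next: Disposable): Promise<Disposable> {
  const old = current; // stash, do not destroy yet
  await next.render(); // new content is now visible
  old?.destroy();      // safe to tear down the previous object
  return next;
}
```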
@@ -185,14 +185,6 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
return didRender;
};

hideObjects = () => {
this.konva.objectGroup.hide();
};

showObjects = () => {
this.konva.objectGroup.show();
};

adoptObjectRenderer = (renderer: AnyObjectRenderer) => {
this.renderers.set(renderer.id, renderer);
renderer.konva.group.moveTo(this.konva.objectGroup);

@@ -10,7 +10,6 @@ import {
|
||||
getKonvaNodeDebugAttrs,
|
||||
getPrefixedId,
|
||||
offsetCoord,
|
||||
roundRect,
|
||||
} from 'features/controlLayers/konva/util';
|
||||
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import type { Coordinate, Rect, RectWithRotation } from 'features/controlLayers/store/types';
|
||||
@@ -774,7 +773,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
const rect = this.getRelativeRect();
|
||||
const rasterizeResult = await withResultAsync(() =>
|
||||
this.parent.renderer.rasterize({
|
||||
rect: roundRect(rect),
|
||||
rect,
|
||||
replaceObjects: true,
|
||||
ignoreCache: true,
|
||||
attrs: { opacity: 1, filters: [] },
|
||||
|
||||
@@ -740,12 +740,3 @@ export const getColorAtCoordinate = (stage: Konva.Stage, coord: Coordinate): Rgb

return { r, g, b };
};

export const roundRect = (rect: Rect): Rect => {
return {
x: Math.round(rect.x),
y: Math.round(rect.y),
width: Math.round(rect.width),
height: Math.round(rect.height),
};
};

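The new `roundRect` helper above snaps a rect's origin and size to integers before it is handed to the rasterizer, which helps keep the rasterized output pixel-aligned. A small usage sketch (the `Rect` shape is assumed from the surrounding code):

```ts
type Rect = { x: number; y: number; width: number; height: number };

const roundRect = (rect: Rect): Rect => ({
  x: Math.round(rect.x),
  y: Math.round(rect.y),
  width: Math.round(rect.width),
  height: Math.round(rect.height),
});

// Fractional rects (e.g. from zoomed or scaled canvas math) become pixel-aligned.
const rect = roundRect({ x: 10.4, y: 3.6, width: 511.7, height: 512.2 });
// -> { x: 10, y: 4, width: 512, height: 512 }
```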
@@ -95,28 +95,6 @@ const zSpandrelFilterConfig = z.object({
|
||||
});
|
||||
export type SpandrelFilterConfig = z.infer<typeof zSpandrelFilterConfig>;
|
||||
|
||||
const zBlurTypes = z.enum(['gaussian', 'box']);
|
||||
export type BlurTypes = z.infer<typeof zBlurTypes>;
|
||||
export const isBlurTypes = (v: unknown): v is BlurTypes => zBlurTypes.safeParse(v).success;
|
||||
const zBlurFilterConfig = z.object({
|
||||
type: z.literal('img_blur'),
|
||||
blur_type: zBlurTypes,
|
||||
radius: z.number().gte(0),
|
||||
});
|
||||
export type BlurFilterConfig = z.infer<typeof zBlurFilterConfig>;
|
||||
|
||||
const zNoiseTypes = z.enum(['gaussian', 'salt_and_pepper']);
|
||||
export type NoiseTypes = z.infer<typeof zNoiseTypes>;
|
||||
export const isNoiseTypes = (v: unknown): v is NoiseTypes => zNoiseTypes.safeParse(v).success;
|
||||
const zNoiseFilterConfig = z.object({
|
||||
type: z.literal('img_noise'),
|
||||
noise_type: zNoiseTypes,
|
||||
amount: z.number().gte(0).lte(1),
|
||||
noise_color: z.boolean(),
|
||||
size: z.number().int().gte(1),
|
||||
});
|
||||
export type NoiseFilterConfig = z.infer<typeof zNoiseFilterConfig>;
|
||||
|
||||
const zFilterConfig = z.discriminatedUnion('type', [
|
||||
zCannyEdgeDetectionFilterConfig,
|
||||
zColorMapFilterConfig,
|
||||
@@ -131,8 +109,6 @@ const zFilterConfig = z.discriminatedUnion('type', [
|
||||
zPiDiNetEdgeDetectionFilterConfig,
|
||||
zDWOpenposeDetectionFilterConfig,
|
||||
zSpandrelFilterConfig,
|
||||
zBlurFilterConfig,
|
||||
zNoiseFilterConfig,
|
||||
]);
|
||||
export type FilterConfig = z.infer<typeof zFilterConfig>;
|
||||
|
||||
@@ -150,8 +126,6 @@ const zFilterType = z.enum([
|
||||
'pidi_edge_detection',
|
||||
'dw_openpose_detection',
|
||||
'spandrel_filter',
|
||||
'img_blur',
|
||||
'img_noise',
|
||||
]);
|
||||
export type FilterType = z.infer<typeof zFilterType>;
|
||||
export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success;
|
||||
@@ -455,62 +429,6 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
|
||||
return true;
|
||||
},
|
||||
},
|
||||
img_blur: {
|
||||
type: 'img_blur',
|
||||
buildDefaults: () => ({
|
||||
type: 'img_blur',
|
||||
blur_type: 'gaussian',
|
||||
radius: 8,
|
||||
}),
|
||||
buildGraph: ({ image_name }, { blur_type, radius }) => {
|
||||
const graph = new Graph(getPrefixedId('img_blur'));
|
||||
const node = graph.addNode({
|
||||
id: getPrefixedId('img_blur'),
|
||||
type: 'img_blur',
|
||||
image: { image_name },
|
||||
blur_type: blur_type,
|
||||
radius: radius,
|
||||
});
|
||||
return {
|
||||
graph,
|
||||
outputNodeId: node.id,
|
||||
};
|
||||
},
|
||||
},
|
||||
img_noise: {
|
||||
type: 'img_noise',
|
||||
buildDefaults: () => ({
|
||||
type: 'img_noise',
|
||||
noise_type: 'gaussian',
|
||||
amount: 0.3,
|
||||
noise_color: true,
|
||||
size: 1,
|
||||
}),
|
||||
buildGraph: ({ image_name }, { noise_type, amount, noise_color, size }) => {
|
||||
const graph = new Graph(getPrefixedId('img_noise'));
|
||||
const node = graph.addNode({
|
||||
id: getPrefixedId('img_noise'),
|
||||
type: 'img_noise',
|
||||
image: { image_name },
|
||||
noise_type: noise_type,
|
||||
amount: amount,
|
||||
noise_color: noise_color,
|
||||
size: size,
|
||||
});
|
||||
const rand = graph.addNode({
|
||||
id: getPrefixedId('rand_int'),
|
||||
use_cache: false,
|
||||
type: 'rand_int',
|
||||
low: 0,
|
||||
high: 2147483647,
|
||||
});
|
||||
graph.addEdge(rand, 'value', node, 'seed');
|
||||
return {
|
||||
graph,
|
||||
outputNodeId: node.id,
|
||||
};
|
||||
},
|
||||
},
|
||||
} as const;
|
||||
|
||||
/**
|
||||
|
||||
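The filters hunks above touch the `img_blur` and `img_noise` filter configs in four places: the zod schemas, the `zFilterConfig` discriminated union, the `zFilterType` enum, and the corresponding `IMAGE_FILTERS` entries. As a sketch of how such a discriminated-union config validates, trimmed down to just these two schemas:

```ts
import { z } from 'zod';

const zBlurFilterConfig = z.object({
  type: z.literal('img_blur'),
  blur_type: z.enum(['gaussian', 'box']),
  radius: z.number().gte(0),
});

const zNoiseFilterConfig = z.object({
  type: z.literal('img_noise'),
  noise_type: z.enum(['gaussian', 'salt_and_pepper']),
  amount: z.number().gte(0).lte(1),
  noise_color: z.boolean(),
  size: z.number().int().gte(1),
});

// The `type` literal selects which schema applies, so mismatched fields fail fast.
const zFilterConfig = z.discriminatedUnion('type', [zBlurFilterConfig, zNoiseFilterConfig]);

const ok = zFilterConfig.parse({ type: 'img_blur', blur_type: 'gaussian', radius: 8 });
const bad = zFilterConfig.safeParse({ type: 'img_noise', amount: 2 }); // fails: amount > 1, fields missing
```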
@@ -1,5 +1,4 @@
|
||||
import type {
|
||||
BlurFilterConfig,
|
||||
CannyEdgeDetectionFilterConfig,
|
||||
ColorMapFilterConfig,
|
||||
ContentShuffleFilterConfig,
|
||||
@@ -13,7 +12,6 @@ import type {
|
||||
LineartEdgeDetectionFilterConfig,
|
||||
MediaPipeFaceDetectionFilterConfig,
|
||||
MLSDDetectionFilterConfig,
|
||||
NoiseFilterConfig,
|
||||
NormalMapFilterConfig,
|
||||
PiDiNetEdgeDetectionFilterConfig,
|
||||
} from 'features/controlLayers/store/filters';
|
||||
@@ -56,7 +54,6 @@ describe('Control Adapter Types', () => {
|
||||
});
|
||||
test('Processor Configs', () => {
|
||||
// Types derived from OpenAPI
|
||||
type _BlurFilterConfig = Required<Pick<Invocation<'img_blur'>, 'type' | 'radius' | 'blur_type'>>;
|
||||
type _CannyEdgeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
|
||||
>;
|
||||
@@ -74,9 +71,6 @@ describe('Control Adapter Types', () => {
|
||||
type _MLSDDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
|
||||
>;
|
||||
type _NoiseFilterConfig = Required<
|
||||
Pick<Invocation<'img_noise'>, 'type' | 'noise_type' | 'amount' | 'noise_color' | 'size'>
|
||||
>;
|
||||
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
|
||||
type _DWOpenposeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
|
||||
@@ -87,7 +81,6 @@ describe('Control Adapter Types', () => {
|
||||
|
||||
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
|
||||
// The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled.
|
||||
assert<Equals<_BlurFilterConfig, BlurFilterConfig>>();
|
||||
assert<Equals<_CannyEdgeDetectionFilterConfig, CannyEdgeDetectionFilterConfig>>();
|
||||
assert<Equals<_ColorMapFilterConfig, ColorMapFilterConfig>>();
|
||||
assert<Equals<_ContentShuffleFilterConfig, ContentShuffleFilterConfig>>();
|
||||
@@ -97,7 +90,6 @@ describe('Control Adapter Types', () => {
|
||||
assert<Equals<_LineartEdgeDetectionFilterConfig, LineartEdgeDetectionFilterConfig>>();
|
||||
assert<Equals<_MediaPipeFaceDetectionFilterConfig, MediaPipeFaceDetectionFilterConfig>>();
|
||||
assert<Equals<_MLSDDetectionFilterConfig, MLSDDetectionFilterConfig>>();
|
||||
assert<Equals<_NoiseFilterConfig, NoiseFilterConfig>>();
|
||||
assert<Equals<_NormalMapFilterConfig, NormalMapFilterConfig>>();
|
||||
assert<Equals<_DWOpenposeDetectionFilterConfig, DWOpenposeDetectionFilterConfig>>();
|
||||
assert<Equals<_PiDiNetEdgeDetectionFilterConfig, PiDiNetEdgeDetectionFilterConfig>>();
|
||||
|
||||
@@ -11,7 +11,7 @@ import type { DndTargetState } from 'features/dnd/types';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { uploadImages } from 'services/api/endpoints/images';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
@@ -72,10 +72,11 @@ export const FullscreenDropzone = memo(() => {
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
const [dndState, setDndState] = useState<DndTargetState>('idle');
|
||||
|
||||
const uploadFilesSchema = useMemo(() => getFilesSchema(maxImageUploadCount), [maxImageUploadCount]);
|
||||
|
||||
const validateAndUploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
const { getState } = getStore();
|
||||
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
|
||||
const parseResult = uploadFilesSchema.safeParse(files);
|
||||
|
||||
if (!parseResult.success) {
|
||||
@@ -104,18 +105,7 @@ export const FullscreenDropzone = memo(() => {
|
||||
|
||||
uploadImages(uploadArgs);
|
||||
},
|
||||
[maxImageUploadCount, t]
|
||||
);
|
||||
|
||||
const onPaste = useCallback(
|
||||
(e: ClipboardEvent) => {
|
||||
if (!e.clipboardData?.files) {
|
||||
return;
|
||||
}
|
||||
const files = Array.from(e.clipboardData.files);
|
||||
validateAndUploadFiles(files);
|
||||
},
|
||||
[validateAndUploadFiles]
|
||||
[maxImageUploadCount, t, uploadFilesSchema]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -154,12 +144,24 @@ export const FullscreenDropzone = memo(() => {
|
||||
}, [validateAndUploadFiles]);
|
||||
|
||||
useEffect(() => {
|
||||
window.addEventListener('paste', onPaste);
|
||||
const controller = new AbortController();
|
||||
|
||||
document.addEventListener(
|
||||
'paste',
|
||||
(e) => {
|
||||
if (!e.clipboardData?.files) {
|
||||
return;
|
||||
}
|
||||
const files = Array.from(e.clipboardData.files);
|
||||
validateAndUploadFiles(files);
|
||||
},
|
||||
{ signal: controller.signal }
|
||||
);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('paste', onPaste);
|
||||
controller.abort();
|
||||
};
|
||||
}, [onPaste]);
|
||||
}, [validateAndUploadFiles]);
|
||||
|
||||
return (
|
||||
<Box ref={ref} data-dnd-state={dndState} sx={sx}>
|
||||
|
||||
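The dropzone hunk above replaces the `window.addEventListener('paste', ...)` / `removeEventListener` pair with a `document`-level listener tied to an `AbortController`, so cleanup is a single `abort()` call. A condensed sketch of that effect, assuming a stable `validateAndUploadFiles` callback as in the code above:

```ts
import { useEffect } from 'react';

// Assumes validateAndUploadFiles is a stable useCallback from the surrounding component.
const usePasteUpload = (validateAndUploadFiles: (files: File[]) => void) => {
  useEffect(() => {
    const controller = new AbortController();

    document.addEventListener(
      'paste',
      (e) => {
        if (!e.clipboardData?.files) {
          return;
        }
        validateAndUploadFiles(Array.from(e.clipboardData.files));
      },
      // Aborting the controller removes the listener - no handler reference needs to be kept around.
      { signal: controller.signal }
    );

    return () => {
      controller.abort();
    };
  }, [validateAndUploadFiles]);
};
```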
@@ -1,4 +1,3 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type {
|
||||
@@ -10,6 +9,7 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import {
|
||||
addImagesToBoard,
|
||||
addImagesToNodeImageFieldCollectionAction,
|
||||
createNewCanvasEntityFromImage,
|
||||
removeImagesFromBoard,
|
||||
replaceCanvasEntityObjectsWithImage,
|
||||
@@ -19,14 +19,10 @@ import {
|
||||
setRegionalGuidanceReferenceImage,
|
||||
setUpscaleInitialImage,
|
||||
} from 'features/imageActions/actions';
|
||||
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('dnd');
|
||||
|
||||
type RecordUnknown = Record<string | symbol, unknown>;
|
||||
|
||||
type DndData<
|
||||
@@ -272,27 +268,15 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
|
||||
}
|
||||
|
||||
const { fieldIdentifier } = targetData.payload;
|
||||
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
newValue.push({ image_name: sourceData.payload.imageDTO.image_name });
|
||||
imageDTOs.push(sourceData.payload.imageDTO);
|
||||
} else {
|
||||
newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name })));
|
||||
imageDTOs.push(...sourceData.payload.imageDTOs);
|
||||
}
|
||||
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue }));
|
||||
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
|
||||
},
|
||||
};
|
||||
//#endregion
|
||||
|
||||
@@ -46,7 +46,6 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
left={0}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
{hasImageToCompare && <CompareToolbar />}
{!hasImageToCompare && <ViewerToolbar closeButton={closeButton} />}

@@ -1,5 +1,4 @@
|
||||
import type { ButtonProps } from '@invoke-ai/ui-library';
|
||||
import { Alert, AlertDescription, AlertIcon, Button, Divider, Flex, Link, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import { Button, Divider, Flex, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { InvokeLogoIcon } from 'common/components/InvokeLogoIcon';
|
||||
@@ -8,10 +7,9 @@ import { $installModelsTab } from 'features/modelManagerV2/subpanels/InstallMode
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { selectIsLocal } from 'features/system/store/configSlice';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiArrowSquareOutBold, PiImageBold } from 'react-icons/pi';
|
||||
import { PiImageBold } from 'react-icons/pi';
|
||||
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||
|
||||
export const NoContentForViewer = memo(() => {
|
||||
@@ -20,105 +18,6 @@ export const NoContentForViewer = memo(() => {
|
||||
const isLocal = useAppSelector(selectIsLocal);
|
||||
const isEnabled = useFeatureStatus('starterModels');
|
||||
const { t } = useTranslation();
|
||||
|
||||
const showStarterBundles = useMemo(() => {
|
||||
return isEnabled && data && mainModels.length === 0;
|
||||
}, [mainModels.length, data, isEnabled]);
|
||||
|
||||
if (hasImages === LOADING_SYMBOL) {
|
||||
// Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered.
|
||||
// If we show the logo while loading, there is an awkward layout shift where the invoke logo moves a bit. Less
|
||||
// jarring to show a blank bg with a spinner - it will only be shown for a moment as we do the initial images
|
||||
// fetching.
|
||||
return <LoadingSpinner />;
|
||||
}
|
||||
|
||||
if (hasImages) {
|
||||
return <IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={8} alignItems="center" textAlign="center" maxW="600px">
|
||||
<InvokeLogoIcon w={32} h={32} />
|
||||
<Flex flexDir="column" gap={4} alignItems="center" textAlign="center">
|
||||
{isLocal ? <GetStartedLocal /> : <GetStartedCommercial />}
|
||||
{showStarterBundles && <StarterBundlesCallout />}
|
||||
<Divider />
|
||||
<GettingStartedVideosCallout />
|
||||
{isLocal && <LowVRAMAlert />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
NoContentForViewer.displayName = 'NoContentForViewer';
|
||||
|
||||
const LoadingSpinner = () => {
|
||||
return (
|
||||
<Flex position="relative" width="full" height="full" alignItems="center" justifyContent="center">
|
||||
<Spinner label="Loading" color="grey" position="absolute" size="sm" width={8} height={8} right={4} bottom={4} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export const ExternalLink = (props: ButtonProps & { href: string }) => {
|
||||
return (
|
||||
<Button
|
||||
as={Link}
|
||||
variant="unstyled"
|
||||
isExternal
|
||||
display="inline-flex"
|
||||
alignItems="center"
|
||||
rightIcon={<PiArrowSquareOutBold />}
|
||||
color="base.50"
|
||||
mt={-1}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const InlineButton = (props: PropsWithChildren<{ onClick: () => void }>) => {
|
||||
return (
|
||||
<Button variant="link" size="md" onClick={props.onClick} color="base.50">
|
||||
{props.children}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
const StrongComponent = <Text as="span" color="base.50" fontSize="md" />;
|
||||
|
||||
const GetStartedLocal = () => {
|
||||
return (
|
||||
<Text fontSize="md" color="base.200">
|
||||
<Trans i18nKey="newUserExperience.toGetStartedLocal" components={{ StrongComponent }} />
|
||||
</Text>
|
||||
);
|
||||
};
|
||||
|
||||
const GetStartedCommercial = () => {
|
||||
return (
|
||||
<Text fontSize="md" color="base.200">
|
||||
<Trans i18nKey="newUserExperience.toGetStarted" components={{ StrongComponent }} />
|
||||
</Text>
|
||||
);
|
||||
};
|
||||
|
||||
const GettingStartedVideosCallout = () => {
|
||||
return (
|
||||
<Text fontSize="md" color="base.200">
|
||||
<Trans
|
||||
i18nKey="newUserExperience.gettingStartedSeries"
|
||||
components={{
|
||||
LinkComponent: (
|
||||
<ExternalLink href="https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO" />
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</Text>
|
||||
);
|
||||
};
|
||||
|
||||
const StarterBundlesCallout = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleClickDownloadStarterModels = useCallback(() => {
|
||||
@@ -131,31 +30,89 @@ const StarterBundlesCallout = () => {
|
||||
$installModelsTab.set(0);
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Text fontSize="md" color="base.200">
|
||||
<Trans
|
||||
i18nKey="newUserExperience.noModelsInstalled"
|
||||
components={{
|
||||
DownloadStarterModelsButton: <InlineButton onClick={handleClickDownloadStarterModels} />,
|
||||
ImportModelsButton: <InlineButton onClick={handleClickImportModels} />,
|
||||
}}
|
||||
/>
|
||||
</Text>
|
||||
);
|
||||
};
|
||||
const showStarterBundles = useMemo(() => {
|
||||
return isEnabled && data && mainModels.length === 0;
|
||||
}, [mainModels.length, data, isEnabled]);
|
||||
|
||||
if (hasImages === LOADING_SYMBOL) {
|
||||
return (
|
||||
// Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered.
|
||||
// If we show the logo while loading, there is an awkward layout shift where the invoke logo moves a bit. Less
|
||||
// jarring to show a blank bg with a spinner - it will only be shown for a moment as we do the initial images
|
||||
// fetching.
|
||||
<Flex position="relative" width="full" height="full" alignItems="center" justifyContent="center">
|
||||
<Spinner label="Loading" color="grey" position="absolute" size="sm" width={8} height={8} right={4} bottom={4} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (hasImages) {
|
||||
return <IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />;
|
||||
}
|
||||
|
||||
const LowVRAMAlert = () => {
|
||||
return (
|
||||
<Alert status="warning" borderRadius="base" fontSize="md" shadow="md" w="fit-content">
|
||||
<AlertIcon />
|
||||
<AlertDescription>
|
||||
<Trans
|
||||
i18nKey="newUserExperience.lowVRAMMode"
|
||||
components={{
|
||||
LinkComponent: <ExternalLink href="https://invoke-ai.github.io/InvokeAI/features/low-vram/" />,
|
||||
}}
|
||||
/>
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
<Flex flexDir="column" gap={4} alignItems="center" textAlign="center" maxW="600px">
|
||||
<InvokeLogoIcon w={40} h={40} />
|
||||
<Flex flexDir="column" gap={8} alignItems="center" textAlign="center">
|
||||
<Text fontSize="md" color="base.200" pt={16}>
|
||||
{isLocal ? (
|
||||
<Trans
|
||||
i18nKey="newUserExperience.toGetStartedLocal"
|
||||
components={{
|
||||
StrongComponent: <Text as="span" color="white" fontSize="md" fontWeight="semibold" />,
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<Trans
|
||||
i18nKey="newUserExperience.toGetStarted"
|
||||
components={{
|
||||
StrongComponent: <Text as="span" color="white" fontSize="md" fontWeight="semibold" />,
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Text>
|
||||
|
||||
{showStarterBundles && (
|
||||
<Flex flexDir="column" gap={2} alignItems="center">
|
||||
<Text fontSize="md" color="base.200">
|
||||
{t('newUserExperience.noModelsInstalled')}
|
||||
</Text>
|
||||
<Flex gap={3} alignItems="center">
|
||||
<Button size="sm" onClick={handleClickDownloadStarterModels}>
|
||||
{t('newUserExperience.downloadStarterModels')}
|
||||
</Button>
|
||||
<Text fontSize="sm" color="base.200">
|
||||
{t('common.or')}
|
||||
</Text>
|
||||
<Button size="sm" onClick={handleClickImportModels}>
|
||||
{t('newUserExperience.importModels')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
<Divider />
|
||||
|
||||
<Text fontSize="md" color="base.200">
|
||||
<Trans
|
||||
i18nKey="newUserExperience.gettingStartedSeries"
|
||||
components={{
|
||||
LinkComponent: (
|
||||
<Text
|
||||
as="a"
|
||||
color="white"
|
||||
fontSize="md"
|
||||
fontWeight="semibold"
|
||||
href="https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO"
|
||||
target="_blank"
|
||||
/>
|
||||
),
|
||||
}}
|
||||
/>
|
||||
</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
NoContentForViewer.displayName = 'NoContentForViewer';
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
@@ -30,15 +31,19 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
|
||||
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { uniqBy } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const setGlobalReferenceImage = (arg: {
|
||||
imageDTO: ImageDTO;
|
||||
entityIdentifier: CanvasEntityIdentifier<'reference_image'>;
|
||||
@@ -72,6 +77,54 @@ export const setNodeImageFieldImage = (arg: {
|
||||
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
|
||||
};
|
||||
|
||||
export const addImagesToNodeImageFieldCollectionAction = (arg: {
|
||||
imageDTOs: ImageDTO[];
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
|
||||
const uniqueImages = uniqBy(images, 'image_name');
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
|
||||
};
|
||||
|
||||
export const removeImageFromNodeImageFieldCollectionAction = (arg: {
|
||||
imageName: string;
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { imageName, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
|
||||
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
|
||||
};
|
||||
|
||||
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
|
||||
const { imageDTO, dispatch } = arg;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
|
||||
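The actions hunk above centralizes adding and removing images on an image-collection field and dedupes by `image_name` with lodash's `uniqBy` before dispatching the new value. A small sketch of the dedupe step in isolation:

```ts
import { uniqBy } from 'lodash-es';

type ImageField = { image_name: string };

// Appends new images and drops duplicates, keeping the first occurrence of each image_name.
const appendUnique = (current: ImageField[], added: ImageField[]): ImageField[] =>
  uniqBy([...current, ...added], 'image_name');

appendUnique([{ image_name: 'a.png' }], [{ image_name: 'a.png' }, { image_name: 'b.png' }]);
// -> [{ image_name: 'a.png' }, { image_name: 'b.png' }]
```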
@@ -43,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
{fieldNames.connectionFields.map((fieldName, i) => (
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
<InvocationInputFieldCheck nodeId={nodeId} fieldName={fieldName}>
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
<InputField nodeId={nodeId} fieldName={fieldName} />
</InvocationInputFieldCheck>
</GridItem>
))}
@@ -59,7 +59,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
<InputField nodeId={nodeId} fieldName={fieldName} />
</InvocationInputFieldCheck>
))}
{fieldNames.missingFields.map((fieldName) => (
@@ -68,7 +68,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
<InputField nodeId={nodeId} fieldName={fieldName} />
</InvocationInputFieldCheck>
))}
</Flex>

@@ -1,7 +1,7 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
|
||||
import {
|
||||
selectWorkflowSlice,
|
||||
workflowExposedFieldAdded,
|
||||
@@ -19,7 +19,7 @@ type Props = {
|
||||
const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const field = useFieldInputInstance(nodeId, fieldName);
|
||||
const value = useFieldValue(nodeId, fieldName);
|
||||
const selectIsExposed = useMemo(
|
||||
() =>
|
||||
createSelector(selectWorkflowSlice, (workflow) => {
|
||||
@@ -31,11 +31,8 @@ const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
|
||||
const isExposed = useAppSelector(selectIsExposed);
|
||||
|
||||
const handleExposeField = useCallback(() => {
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, field }));
|
||||
}, [dispatch, field, fieldName, nodeId]);
|
||||
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, value }));
|
||||
}, [dispatch, fieldName, nodeId, value]);
|
||||
|
||||
const handleUnexposeField = useCallback(() => {
|
||||
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
|
||||
|
||||
@@ -14,10 +14,9 @@ import { InputFieldWrapper } from './InputFieldWrapper';
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
isLinearView: boolean;
|
||||
}
|
||||
|
||||
const InputField = ({ nodeId, fieldName, isLinearView }: Props) => {
|
||||
const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const isInvalid = useFieldIsInvalid(nodeId, fieldName);
|
||||
@@ -70,12 +69,12 @@ const InputField = ({ nodeId, fieldName, isLinearView }: Props) => {
|
||||
px={2}
|
||||
>
|
||||
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
<Flex gap={1} alignItems="center">
|
||||
<Flex gap={1}>
|
||||
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="inputs" isInvalid={isInvalid} withTooltip />
|
||||
{isHovered && <FieldResetToDefaultValueButton nodeId={nodeId} fieldName={fieldName} />}
|
||||
{isHovered && <FieldLinearViewToggle nodeId={nodeId} fieldName={fieldName} />}
|
||||
</Flex>
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={isLinearView} />
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||
</Flex>
|
||||
</FormControl>
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
import { NumberFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent';
|
||||
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import {
|
||||
@@ -23,8 +21,6 @@ import {
|
||||
isControlNetModelFieldInputTemplate,
|
||||
isEnumFieldInputInstance,
|
||||
isEnumFieldInputTemplate,
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isFloatFieldInputInstance,
|
||||
isFloatFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
@@ -35,8 +31,6 @@ import {
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
isIntegerFieldInputTemplate,
|
||||
isIPAdapterModelFieldInputInstance,
|
||||
@@ -57,8 +51,6 @@ import {
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSpandrelImageToImageModelFieldInputInstance,
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
@@ -99,285 +91,96 @@ import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
type InputFieldProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
isLinearView: boolean;
|
||||
};
|
||||
|
||||
const InputFieldRenderer = ({ nodeId, fieldName, isLinearView }: InputFieldProps) => {
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
|
||||
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<StringFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<StringFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<BooleanFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<NumberFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<NumberFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<NumberFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<NumberFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
if (
|
||||
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
|
||||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
|
||||
) {
|
||||
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
return (
<EnumFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isImageFieldCollectionInputInstance(fieldInstance) && isImageFieldCollectionInputTemplate(fieldTemplate)) {
return (
<ImageFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ImageFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
return (
<ImageFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) {
return (
<BoardFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <BoardFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) {
return (
<MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
return (
<ModelIdentifierFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return (
<RefinerModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) {
return (
<VAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
return (
<T5EncoderModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
return (
<CLIPEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
return (
<CLIPLEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
return (
<CLIPGEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isControlLoRAModelFieldInputInstance(fieldInstance) && isControlLoRAModelFieldInputTemplate(fieldTemplate)) {
return (
<ControlLoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return (
<FluxVAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return (
<LoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) {
return (
<ControlNetModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) {
return (
<IPAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
return (
<T2IAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (
@@ -389,64 +192,28 @@ const InputFieldRenderer = ({ nodeId, fieldName, isLinearView }: InputFieldProps
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}

if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return (
<ColorFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return (
<FluxMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return (
<SD3MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return (
<SDXLMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return (
<SchedulerFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

if (fieldTemplate) {

@@ -1,6 +1,6 @@
import { Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectInvocationNode, selectNodesSlice } from 'features/nodes/store/selectors';
@@ -18,7 +18,7 @@ export const InvocationInputFieldCheck = memo(({ nodeId, fieldName, children }:
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodesSlice) => {
createSelector(selectNodesSlice, (nodesSlice) => {
const node = selectInvocationNode(nodesSlice, nodeId);
const instance = node.data.inputs[fieldName];
const template = templates[node.data.type];

@@ -97,11 +97,7 @@ const LinearViewFieldInternal = ({ fieldIdentifier }: Props) => {
icon={<PiTrashSimpleBold />}
/>
</Flex>
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
isLinearView={true}
/>
<InputFieldRenderer nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName} />
</Flex>
</Flex>
<DndListDropIndicator dndState={dndListState} />

@@ -26,7 +26,7 @@ const EnumFieldInputComponent = (props: FieldComponentProps<EnumFieldInputInstan
);

return (
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value} size="sm">
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value}>
{fieldTemplate.options.map((option) => (
<option key={option} value={option}>
{fieldTemplate.ui_choice_labels ? fieldTemplate.ui_choice_labels[option] : option}
@@ -1,65 +0,0 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
import {
type FloatRangeStartStepCountGenerator,
getDefaultFloatRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';

type FloatRangeGeneratorProps = {
state: FloatRangeStartStepCountGenerator;
onChange: (state: FloatRangeStartStepCountGenerator) => void;
};

export const FloatRangeGenerator = memo(({ state, onChange }: FloatRangeGeneratorProps) => {
const { t } = useTranslation();

const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);

const onReset = useCallback(() => {
onChange(getDefaultFloatRangeStartStepCountGenerator());
}, [onChange]);

return (
<Flex gap={1} alignItems="flex-end" p={1}>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<IconButton
onClick={onReset}
aria-label={t('common.reset')}
icon={<PiArrowCounterClockwiseBold />}
variant="ghost"
/>
</Flex>
);
});

FloatRangeGenerator.displayName = 'FloatRangeGenerator';
@@ -10,9 +10,9 @@ import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type { ImageField } from 'features/nodes/types/common';
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
@@ -61,12 +61,15 @@ export const ImageFieldCollectionInputComponent = memo(
);

const onRemoveImage = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldImageCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
(imageName: string) => {
removeImageFromNodeImageFieldCollectionAction({
imageName,
fieldIdentifier: { nodeId, fieldName: field.name },
dispatch: store.dispatch,
getState: store.getState,
});
},
[field.name, field.value, nodeId, store]
[field.name, nodeId, store.dispatch, store.getState]
);

return (
@@ -87,7 +90,7 @@ export const ImageFieldCollectionInputComponent = memo(
isError={isInvalid}
onUpload={onUpload}
fontSize={24}
variant="ghost"
variant="outline"
/>
)}
{field.value && field.value.length > 0 && (
@@ -99,9 +102,9 @@ export const ImageFieldCollectionInputComponent = memo(
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(4, 1fr)" gap={1}>
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<ImageGridItemContent value={value} index={index} onRemoveImage={onRemoveImage} />
{field.value.map(({ image_name }) => (
<GridItem key={image_name} position="relative" className="nodrag">
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
</GridItem>
))}
</Grid>
@@ -121,11 +124,11 @@ export const ImageFieldCollectionInputComponent = memo(
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';

const ImageGridItemContent = memo(
({ value, index, onRemoveImage }: { value: ImageField; index: number; onRemoveImage: (index: number) => void }) => {
const query = useGetImageDTOQuery(value.image_name);
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
const query = useGetImageDTOQuery(imageName);
const onClickRemove = useCallback(() => {
onRemoveImage(index);
}, [index, onRemoveImage]);
onRemoveImage(imageName);
}, [imageName, onRemoveImage]);

if (query.isLoading) {
return <IAINoContentFallbackWithSpinner />;
@@ -1,320 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Button,
CompositeNumberInput,
Divider,
Flex,
FormControl,
FormLabel,
Grid,
GridItem,
IconButton,
Switch,
Text,
} from '@invoke-ai/ui-library';
import { NUMPY_RAND_MAX } from 'app/constants';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { FloatRangeGenerator } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import {
fieldNumberCollectionGeneratorCommitted,
fieldNumberCollectionGeneratorStateChanged,
fieldNumberCollectionGeneratorToggled,
fieldNumberCollectionLockLinearViewToggled,
fieldNumberCollectionValueChanged,
} from 'features/nodes/store/nodesSlice';
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import type {
FloatRangeStartStepCountGenerator,
IntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLockSimpleFill, PiLockSimpleOpenFill, PiXBold } from 'react-icons/pi';

import type { FieldComponentProps } from './types';

const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;

const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;

export const NumberFieldCollectionInputComponent = memo(
(
props:
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
) => {
const { nodeId, field, fieldTemplate, isLinearView } = props;
const store = useAppStore();
const { t } = useTranslation();

const isInvalid = useFieldIsInvalid(nodeId, field.name);
const isIntegerField = useMemo(() => fieldTemplate.type.name === 'IntegerField', [fieldTemplate.type]);

const onRemoveNumber = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);

const onChangeNumber = useCallback(
(index: number, value: number) => {
const newValue = field.value ? [...field.value] : [];
newValue[index] = value;
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);

const onAddNumber = useCallback(() => {
const newValue = field.value ? [...field.value, 0] : [0];
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);

const min = useMemo(() => {
let min = -NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.minimum)) {
min = fieldTemplate.minimum;
}
if (!isNil(fieldTemplate.exclusiveMinimum)) {
min = fieldTemplate.exclusiveMinimum + 0.01;
}
return min;
}, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);

const max = useMemo(() => {
let max = NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.maximum)) {
max = fieldTemplate.maximum;
}
if (!isNil(fieldTemplate.exclusiveMaximum)) {
max = fieldTemplate.exclusiveMaximum - 0.01;
}
return max;
}, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);

const step = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.1;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);

const fineStep = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.01;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);

const toggleGenerator = useCallback(() => {
store.dispatch(fieldNumberCollectionGeneratorToggled({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);

const onChangeGenerator = useCallback(
(generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator) => {
store.dispatch(fieldNumberCollectionGeneratorStateChanged({ nodeId, fieldName: field.name, generatorState }));
},
[field.name, nodeId, store]
);

const onCommitGenerator = useCallback(() => {
store.dispatch(fieldNumberCollectionGeneratorCommitted({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);

const onToggleLockLinearView = useCallback(() => {
store.dispatch(fieldNumberCollectionLockLinearViewToggled({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);

const valuesAsString = useMemo(() => {
const resolvedValue = resolveNumberFieldCollectionValue(field);
return resolvedValue ? resolvedValue.map((val) => round(val, 2)).join(', ') : '';
}, [field]);

const isLockedOnLinearView = !(field.lockLinearView && isLinearView);

return (
<Flex
className="nodrag"
position="relative"
w="full"
h="auto"
maxH={64}
alignItems="stretch"
justifyContent="center"
p={1}
sx={sx}
data-error={isInvalid}
borderRadius="base"
flexDir="column"
gap={1}
>
<Flex w="full" gap={2}>
{!field.generator && (
<Button onClick={onAddNumber} variant="ghost" flexGrow={1} size="sm">
{t('nodes.addValue')}
</Button>
)}
{field.generator && isLockedOnLinearView && (
<Button
tooltip={
<Flex p={1} flexDir="column">
<Text fontWeight="semibold">{t('nodes.generatedValues')}:</Text>
<Text fontFamily="monospace">{valuesAsString}</Text>
</Flex>
}
onClick={onCommitGenerator}
variant="ghost"
flexGrow={1}
size="sm"
>
{t('nodes.commitValues')}
</Button>
)}
{isLockedOnLinearView && (
<FormControl w="min-content" pe={isLinearView ? 2 : undefined}>
<FormLabel m={0}>{t('nodes.generator')}</FormLabel>
<Switch onChange={toggleGenerator} isChecked={Boolean(field.generator)} size="sm" />
</FormControl>
)}
{!isLinearView && (
<IconButton
onClick={onToggleLockLinearView}
tooltip={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
aria-label={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
icon={field.lockLinearView ? <PiLockSimpleFill /> : <PiLockSimpleOpenFill />}
variant="ghost"
size="sm"
/>
)}
</Flex>
{!field.generator && field.value && field.value.length > 0 && (
<>
{!(field.lockLinearView && isLinearView) && <Divider />}
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid gap={1} gridTemplateColumns="auto 1fr auto" alignItems="center">
{field.value.map((value, index) => (
<NumberListItemContent
key={index}
value={value}
index={index}
min={min}
max={max}
step={step}
fineStep={fineStep}
isIntegerField={isIntegerField}
onRemoveNumber={onRemoveNumber}
onChangeNumber={onChangeNumber}
/>
))}
</Grid>
</OverlayScrollbarsComponent>
</>
)}
{field.generator && field.generator.type === 'float-range-generator-start-step-count' && (
<>
{!(field.lockLinearView && isLinearView) && <Divider />}
<FloatRangeGenerator state={field.generator} onChange={onChangeGenerator} />
</>
)}
</Flex>
);
}
);

NumberFieldCollectionInputComponent.displayName = 'NumberFieldCollectionInputComponent';

type NumberListItemContentProps = {
value: number;
index: number;
isIntegerField: boolean;
min: number;
max: number;
step: number;
fineStep: number;
onRemoveNumber: (index: number) => void;
onChangeNumber: (index: number, value: number) => void;
};

const NumberListItemContent = memo(
({
value,
index,
isIntegerField,
min,
max,
step,
fineStep,
onRemoveNumber,
onChangeNumber,
}: NumberListItemContentProps) => {
const { t } = useTranslation();

const onClickRemove = useCallback(() => {
onRemoveNumber(index);
}, [index, onRemoveNumber]);
const onChange = useCallback(
(v: number) => {
onChangeNumber(index, isIntegerField ? Math.floor(Number(v)) : Number(v));
},
[index, isIntegerField, onChangeNumber]
);

return (
<>
<GridItem>
<FormLabel ps={1} m={0}>
{index + 1}.
</FormLabel>
</GridItem>
<GridItem>
<CompositeNumberInput
onChange={onChange}
value={value}
min={min}
max={max}
step={step}
fineStep={fineStep}
className="nodrag"
flexGrow={1}
/>
</GridItem>
<GridItem>
<IconButton
tabIndex={-1}
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.delete')}
/>
</GridItem>
</>
);
}
);
NumberListItemContent.displayName = 'NumberListItemContent';
@@ -1,152 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Grid, GridItem, IconButton, Input } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldStringCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type {
StringFieldCollectionInputInstance,
StringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold, PiXBold } from 'react-icons/pi';

import type { FieldComponentProps } from './types';

const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;

const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;

export const StringFieldCollectionInputComponent = memo(
(props: FieldComponentProps<StringFieldCollectionInputInstance, StringFieldCollectionInputTemplate>) => {
const { nodeId, field } = props;
const store = useAppStore();

const isInvalid = useFieldIsInvalid(nodeId, field.name);

const onRemoveString = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);

const onChangeString = useCallback(
(index: number, value: string) => {
const newValue = field.value ? [...field.value] : [];
newValue[index] = value;
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);

const onAddString = useCallback(() => {
const newValue = field.value ? [...field.value, ''] : [''];
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);

return (
<Flex
className="nodrag"
position="relative"
w="full"
h="full"
maxH={64}
alignItems="stretch"
justifyContent="center"
>
{(!field.value || field.value.length === 0) && (
<Box w="full" sx={sx} data-error={isInvalid} borderRadius="base">
<IconButton
w="full"
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
</Box>
)}
{field.value && field.value.length > 0 && (
<Box w="full" h="auto" p={1} sx={sx} data-error={isInvalid} borderRadius="base">
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(1, 1fr)" gap={1}>
<IconButton
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<StringListItemContent
value={value}
index={index}
onRemoveString={onRemoveString}
onChangeString={onChangeString}
/>
</GridItem>
))}
</Grid>
</OverlayScrollbarsComponent>
</Box>
)}
</Flex>
);
}
);

StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';

type StringListItemContentProps = {
value: string;
index: number;
onRemoveString: (index: number) => void;
onChangeString: (index: number, value: string) => void;
};

const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
const { t } = useTranslation();

const onClickRemove = useCallback(() => {
onRemoveString(index);
}, [index, onRemoveString]);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChangeString(index, e.target.value);
},
[index, onChangeString]
);
return (
<Flex alignItems="center" gap={1}>
<Input size="xs" resize="none" value={value} onChange={onChange} />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.remove')}
tooltip={t('common.remove')}
/>
</Flex>
);
});
StringListItemContent.displayName = 'StringListItemContent';
@@ -4,5 +4,4 @@ export type FieldComponentProps<V extends FieldInputInstance, T extends FieldInp
nodeId: string;
field: V;
fieldTemplate: T;
isLinearView: boolean;
};

@@ -1,14 +1,12 @@
import type { SystemStyleObject, TextProps } from '@invoke-ai/ui-library';
import { Box, Editable, EditableInput, Flex, Text, useEditableControls } from '@invoke-ai/ui-library';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Editable, EditableInput, EditablePreview, Flex, useEditableControls } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import type { MouseEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';

type Props = {
@@ -19,8 +17,6 @@ type Props = {
const NodeTitle = ({ nodeId, title }: Props) => {
const dispatch = useAppDispatch();
const label = useNodeLabel(nodeId);
const batchGroupId = useBatchGroupId(nodeId);
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
const templateTitle = useNodeTemplateTitle(nodeId);
const { t } = useTranslation();

@@ -33,16 +29,6 @@ const NodeTitle = ({ nodeId, title }: Props) => {
[dispatch, nodeId, title, templateTitle, label, t]
);

const localTitleWithBatchGroupId = useMemo(() => {
if (!batchGroupId) {
return localTitle;
}
if (batchGroupId === 'None') {
return `${localTitle} (${t('nodes.noBatchGroup')})`;
}
return `${localTitle} (${batchGroupId})`;
}, [batchGroupId, localTitle, t]);

const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
@@ -64,16 +50,7 @@ const NodeTitle = ({ nodeId, title }: Props) => {
w="full"
h="full"
>
<Preview
fontSize="sm"
p={0}
w="full"
noOfLines={1}
color={batchGroupColorToken}
fontWeight={batchGroupId ? 'semibold' : undefined}
>
{localTitleWithBatchGroupId}
</Preview>
<EditablePreview fontSize="sm" p={0} w="full" noOfLines={1} />
<EditableInput className="nodrag" fontSize="sm" sx={editableInputStyles} />
<EditableControls />
</Editable>
@@ -83,16 +60,6 @@ const NodeTitle = ({ nodeId, title }: Props) => {

export default memo(NodeTitle);

const Preview = (props: TextProps) => {
const { isEditing } = useEditableControls();

if (isEditing) {
return null;
}

return <Text {...props} />;
};

function EditableControls() {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback(

@@ -5,7 +5,7 @@ import NodeOpacitySlider from './NodeOpacitySlider';
import ViewportControls from './ViewportControls';

const BottomLeftPanel = () => (
<Flex gap={2} position="absolute" bottom={2} insetInlineStart={2}>
<Flex gap={2} position="absolute" bottom={0} insetInlineStart={0}>
<ViewportControls />
<NodeOpacitySlider />
</Flex>

@@ -20,7 +20,7 @@ const MinimapPanel = () => {
const shouldShowMinimapPanel = useAppSelector(selectShouldShowMinimapPanel);

return (
<Flex gap={2} position="absolute" bottom={2} insetInlineEnd={2}>
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
{shouldShowMinimapPanel && (
<ChakraMiniMap
pannable

@@ -12,7 +12,7 @@ import { memo } from 'react';
const TopCenterPanel = () => {
const name = useAppSelector(selectWorkflowName);
return (
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
<Flex gap={2} top={0} left={0} right={0} position="absolute" alignItems="flex-start" pointerEvents="none">
<Flex gap="2">
<AddNodeButton />
<UpdateNodesButton />

@@ -46,7 +46,7 @@ const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
</Flex>
</Tooltip>
</Flex>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={true} />
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</Flex>
);
};
@@ -1,22 +0,0 @@
import { useMemo } from 'react';

export const useBatchGroupColorToken = (batchGroupId?: string) => {
const batchGroupColorToken = useMemo(() => {
switch (batchGroupId) {
case 'Group 1':
return 'invokeGreen.300';
case 'Group 2':
return 'invokeBlue.300';
case 'Group 3':
return 'invokePurple.200';
case 'Group 4':
return 'invokeRed.300';
case 'Group 5':
return 'invokeYellow.300';
default:
return undefined;
}
}, [batchGroupId]);

return batchGroupColorToken;
};
@@ -1,19 +0,0 @@
import { useNode } from 'features/nodes/hooks/useNode';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { useMemo } from 'react';

export const useBatchGroupId = (nodeId: string) => {
const node = useNode(nodeId);

const batchGroupId = useMemo(() => {
if (!isInvocationNode(node)) {
return;
}
if (!isBatchNode(node)) {
return;
}
return node.data.inputs['batch_group_id']?.value as string;
}, [node]);

return batchGroupId;
};
@@ -3,21 +3,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import {
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import {
validateImageFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
} from 'features/nodes/types/fieldValidators';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';

export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
@@ -49,27 +35,13 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
}

// Else special handling for individual field types

if (isImageFieldCollectionInputInstance(field) && isImageFieldCollectionInputTemplate(template)) {
if (validateImageFieldCollectionValue(field.value, template).length > 0) {
// Image collections may have min or max item counts
if (template.minItems !== undefined && field.value.length < template.minItems) {
return true;
}
}

if (isStringFieldCollectionInputInstance(field) && isStringFieldCollectionInputTemplate(template)) {
if (validateStringFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}

if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field, template).length > 0) {
return true;
}
}

if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field, template).length > 0) {
if (template.maxItems !== undefined && field.value.length > template.maxItems) {
return true;
}
}
@@ -1,9 +1,8 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isEqual } from 'lodash-es';
import { useCallback, useMemo } from 'react';

@@ -11,38 +10,19 @@ export const useFieldOriginalValue = (nodeId: string, fieldName: string) => {
const dispatch = useAppDispatch();
const selectOriginalExposedFieldValues = useMemo(
() =>
createMemoizedSelector(selectWorkflowSlice, (workflow) =>
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)
createSelector(
selectWorkflowSlice,
(workflow) =>
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)?.value
),
[nodeId, fieldName]
);
const exposedField = useAppSelector(selectOriginalExposedFieldValues);
const field = useFieldInputInstance(nodeId, fieldName);
const isValueChanged = useMemo(() => {
if (!field) {
// Field is not found, so it is not changed
return false;
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputInstance(exposedField?.field)) {
return !isEqual(field.generator, exposedField.field.generator);
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputInstance(exposedField?.field)) {
return !isEqual(field.generator, exposedField.field.generator);
}
return !isEqual(field.value, exposedField?.field.value);
}, [field, exposedField]);
const originalValue = useAppSelector(selectOriginalExposedFieldValues);
const value = useFieldValue(nodeId, fieldName);
const isValueChanged = useMemo(() => !isEqual(value, originalValue), [value, originalValue]);
const onReset = useCallback(() => {
if (!exposedField) {
return;
}
const { value } = exposedField.field;
const generator =
isIntegerFieldCollectionInputInstance(exposedField.field) ||
isFloatFieldCollectionInputInstance(exposedField.field)
? exposedField.field.generator
: undefined;
dispatch(fieldValueReset({ nodeId, fieldName, value, generator }));
}, [dispatch, fieldName, nodeId, exposedField]);
dispatch(fieldValueReset({ nodeId, fieldName, value: originalValue }));
}, [dispatch, fieldName, nodeId, originalValue]);

return { originalValue: exposedField, isValueChanged, onReset };
return { originalValue, isValueChanged, onReset };
};

@@ -19,7 +19,6 @@ import type {
FluxVAEModelFieldValue,
ImageFieldCollectionValue,
ImageFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IPAdapterModelFieldValue,
LoRAModelFieldValue,
@@ -29,15 +28,12 @@ import type {
SDXLRefinerModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldCollectionValue,
StringFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
} from 'features/nodes/types/field';
import {
isFloatFieldCollectionInputInstance,
isIntegerFieldCollectionInputInstance,
zBoardFieldValue,
zBooleanFieldValue,
zCLIPEmbedModelFieldValue,
@@ -47,12 +43,10 @@ import {
zControlLoRAModelFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
zFloatFieldCollectionValue,
zFloatFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
zImageFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIPAdapterModelFieldValue,
zLoRAModelFieldValue,
@@ -62,22 +56,11 @@ import {
zSDXLRefinerModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldCollectionValue,
zStringFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type {
FloatRangeStartStepCountGenerator,
IntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import {
floatRangeStartStepCountGenerator,
getDefaultFloatRangeStartStepCountGenerator,
getDefaultIntegerRangeStartStepCountGenerator,
integerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom, computed } from 'nanostores';
@@ -95,22 +78,11 @@ const initialNodesState: NodesState = {
edges: [],
};

type FieldValueAction<T extends FieldValue, U = unknown> = PayloadAction<
{
nodeId: string;
fieldName: string;
value: T;
} & U
>;

const selectField = (state: NodesState, nodeId: string, fieldName: string) => {
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
return node.data?.inputs[fieldName];
};
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
fieldName: string;
value: T;
}>;

const fieldValueReducer = <T extends FieldValue>(
state: NodesState,
@@ -118,24 +90,17 @@ const fieldValueReducer = <T extends FieldValue>(
schema: z.ZodTypeAny
) => {
const { nodeId, fieldName, value } = action.payload;
const field = selectField(state, nodeId, fieldName);
const result = schema.safeParse(value);
if (!field || !result.success) {
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
field.value = result.data;
// Special handling if the field value is being reset
if (result.data === undefined) {
if (isFloatFieldCollectionInputInstance(field)) {
if (field.lockLinearView && field.generator) {
field.generator = getDefaultFloatRangeStartStepCountGenerator();
}
} else if (isIntegerFieldCollectionInputInstance(field)) {
if (field.lockLinearView && field.generator) {
field.generator = getDefaultIntegerRangeStartStepCountGenerator();
}
}
const input = node.data?.inputs[fieldName];
const result = schema.safeParse(value);
if (!input || nodeIndex < 0 || !result.success) {
return;
}
input.value = result.data;
};

export const nodesSlice = createSlice({
@@ -340,123 +305,15 @@ export const nodesSlice = createSlice({
}
node.data.notes = notes;
},
fieldValueReset: (
state,
action: FieldValueAction<
StatefulFieldValue,
{ generator?: IntegerRangeStartStepCountGenerator | FloatRangeStartStepCountGenerator }
>
) => {
const { nodeId, fieldName, value, generator } = action.payload;
const field = selectField(state, nodeId, fieldName);
const result = zStatefulFieldValue.safeParse(value);

if (!field || !result.success) {
return;
}

field.value = result.data;

if (isFloatFieldCollectionInputInstance(field) && generator?.type === 'float-range-generator-start-step-count') {
field.generator = generator;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
generator?.type === 'integer-range-generator-start-step-count'
) {
field.generator = generator;
}
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
fieldValueReducer(state, action, zStatefulFieldValue);
},
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
fieldValueReducer(state, action, zStringFieldValue);
},
fieldStringCollectionValueChanged: (state, action: FieldValueAction<StringFieldCollectionValue>) => {
fieldValueReducer(state, action, zStringFieldCollectionValue);
},
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
},
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
},
fieldNumberCollectionGeneratorToggled: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (isFloatFieldCollectionInputInstance(field)) {
field.generator = field.generator ? undefined : getDefaultFloatRangeStartStepCountGenerator();
} else if (isIntegerFieldCollectionInputInstance(field)) {
field.generator = field.generator ? undefined : getDefaultIntegerRangeStartStepCountGenerator();
} else {
// This should never happen
}
},
fieldNumberCollectionGeneratorStateChanged: (
state,
action: PayloadAction<{
nodeId: string;
fieldName: string;
generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator;
}>
) => {
const { nodeId, fieldName, generatorState } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (
isFloatFieldCollectionInputInstance(field) &&
generatorState.type === 'float-range-generator-start-step-count'
) {
field.generator = generatorState;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
generatorState.type === 'integer-range-generator-start-step-count'
) {
field.generator = generatorState;
} else {
// This should never happen
}
},
fieldNumberCollectionGeneratorCommitted: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (
isFloatFieldCollectionInputInstance(field) &&
field.generator &&
field.generator.type === 'float-range-generator-start-step-count'
) {
field.value = floatRangeStartStepCountGenerator(field.generator);
field.generator = undefined;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
field.generator &&
field.generator.type === 'integer-range-generator-start-step-count'
) {
field.value = integerRangeStartStepCountGenerator(field.generator);
field.generator = undefined;
} else {
// This should never happen
}
},
fieldNumberCollectionLockLinearViewToggled: (
state,
action: PayloadAction<{ nodeId: string; fieldName: string }>
) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (!isFloatFieldCollectionInputInstance(field) && !isIntegerFieldCollectionInputInstance(field)) {
return;
}
field.lockLinearView = !field.lockLinearView;
},
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
fieldValueReducer(state, action, zBooleanFieldValue);
},
@@ -578,15 +435,9 @@ export const {
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
fieldNumberCollectionGeneratorToggled,
fieldNumberCollectionGeneratorStateChanged,
fieldNumberCollectionGeneratorCommitted,
fieldNumberCollectionLockLinearViewToggled,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
@@ -695,11 +546,9 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldLoRAModelValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,

@@ -1,8 +1,8 @@
import type {
FieldIdentifier,
FieldInputInstance,
FieldInputTemplate,
FieldOutputTemplate,
StatefulFieldValue,
} from 'features/nodes/types/field';
import type {
AnyNode,
@@ -31,15 +31,15 @@ export type NodesState = {
};

export type WorkflowMode = 'edit' | 'view';
export type FieldIdentifierWithInstance = FieldIdentifier & {
field: FieldInputInstance;
export type FieldIdentifierWithValue = FieldIdentifier & {
value: StatefulFieldValue;
};

export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
_version: 2;
_version: 1;
isTouched: boolean;
mode: WorkflowMode;
originalExposedFieldValues: FieldIdentifierWithInstance[];
originalExposedFieldValues: FieldIdentifierWithValue[];
searchTerm: string;
orderBy?: WorkflowRecordOrderBy;
orderDirection: SQLiteDirection;

@@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
import type {
FieldIdentifierWithInstance,
FieldIdentifierWithValue,
WorkflowMode,
WorkflowsState as WorkflowState,
} from 'features/nodes/store/types';
@@ -31,7 +31,7 @@ const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
};

const initialWorkflowState: WorkflowState = {
_version: 2,
_version: 1,
isTouched: false,
mode: 'view',
originalExposedFieldValues: [],
@@ -62,7 +62,7 @@ export const workflowSlice = createSlice({
const { id, isOpen } = action.payload;
state.categorySections[id] = isOpen;
},
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithInstance>) => {
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithValue>) => {
state.exposedFields = uniqBy(
state.exposedFields.concat(omit(action.payload, 'value')),
(field) => `${field.nodeId}-${field.fieldName}`
@@ -128,25 +128,25 @@ export const workflowSlice = createSlice({
builder.addCase(workflowLoaded, (state, action) => {
const { nodes, edges: _edges, ...workflowExtra } = action.payload;

const originalExposedFieldValues: FieldIdentifierWithInstance[] = [];
const originalExposedFieldValues: FieldIdentifierWithValue[] = [];

workflowExtra.exposedFields.forEach(({ nodeId, fieldName }) => {
const node = nodes.find((n) => n.id === nodeId);
workflowExtra.exposedFields.forEach((field) => {
const node = nodes.find((n) => n.id === field.nodeId);

if (!isInvocationNode(node)) {
return;
}

const field = node.data.inputs[fieldName];
const input = node.data.inputs[field.fieldName];

if (!field) {
if (!input) {
return;
}

const originalExposedFieldValue = {
nodeId,
fieldName,
field,
nodeId: field.nodeId,
fieldName: field.fieldName,
value: input.value,
};
originalExposedFieldValues.push(originalExposedFieldValue);
});
@@ -243,9 +243,6 @@ const migrateWorkflowState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
return deepClone(initialWorkflowState);
}
return state;
};
@@ -1,8 +1,3 @@
|
||||
import {
|
||||
zFloatRangeStartStepCountGenerator,
|
||||
zIntegerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
|
||||
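A recurring difference throughout the rest of this file is how runtime type guards are declared: one side derives them with the shared `buildTypeGuard` helper imported above, the other writes inline `safeParse(val).success` arrow functions. The helper itself is not shown in this diff; a minimal sketch of what such a zod-based guard factory could look like (an assumption, not the actual `parameterSchemas` implementation):

```ts
import { z } from 'zod';

// Hypothetical generic guard factory: wraps a zod schema in a type predicate.
export const buildTypeGuard =
  <T extends z.ZodTypeAny>(schema: T) =>
  (val: unknown): val is z.infer<T> =>
    schema.safeParse(val).success;
```
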
@@ -83,35 +78,14 @@ const zIntegerFieldType = zFieldTypeBase.extend({
name: z.literal('IntegerField'),
originalType: zStatelessFieldType.optional(),
});
const zIntegerCollectionFieldType = z.object({
name: z.literal('IntegerField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isIntegerCollectionFieldType = buildTypeGuard(zIntegerCollectionFieldType);

const zFloatFieldType = zFieldTypeBase.extend({
name: z.literal('FloatField'),
originalType: zStatelessFieldType.optional(),
});
const zFloatCollectionFieldType = z.object({
name: z.literal('FloatField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isFloatCollectionFieldType = buildTypeGuard(zFloatCollectionFieldType);

const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
originalType: zStatelessFieldType.optional(),
});
const zStringCollectionFieldType = z.object({
name: z.literal('StringField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isStringCollectionFieldType = buildTypeGuard(zStringCollectionFieldType);

const zBooleanFieldType = zFieldTypeBase.extend({
name: z.literal('BooleanField'),
originalType: zStatelessFieldType.optional(),
@@ -129,7 +103,9 @@ const zImageCollectionFieldType = z.object({
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isImageCollectionFieldType = buildTypeGuard(zImageCollectionFieldType);
export const isImageCollectionFieldType = (
fieldType: FieldType
): fieldType is z.infer<typeof zImageCollectionFieldType> => zImageCollectionFieldType.safeParse(fieldType).success;
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
originalType: zStatelessFieldType.optional(),
@@ -278,48 +254,10 @@ const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
export type IntegerFieldInputTemplate = z.infer<typeof zIntegerFieldInputTemplate>;
export const isIntegerFieldInputInstance = buildTypeGuard(zIntegerFieldInputInstance);
export const isIntegerFieldInputTemplate = buildTypeGuard(zIntegerFieldInputTemplate);
// #endregion

// #region IntegerField Collection
export const zIntegerFieldCollectionValue = z.array(zIntegerFieldValue).optional();
const zIntegerFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zIntegerFieldCollectionValue,
generator: zIntegerRangeStartStepCountGenerator.optional(),
lockLinearView: z.boolean().default(false),
});
const zIntegerFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zIntegerCollectionFieldType,
originalType: zFieldType.optional(),
default: zIntegerFieldCollectionValue,
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
multipleOf: z.number().int().optional(),
maximum: z.number().int().optional(),
exclusiveMaximum: z.number().int().optional(),
minimum: z.number().int().optional(),
exclusiveMinimum: z.number().int().optional(),
})
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);

const zIntegerFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIntegerCollectionFieldType,
});
export type IntegerFieldCollectionValue = z.infer<typeof zIntegerFieldCollectionValue>;
export type IntegerFieldCollectionInputInstance = z.infer<typeof zIntegerFieldCollectionInputInstance>;
export type IntegerFieldCollectionInputTemplate = z.infer<typeof zIntegerFieldCollectionInputTemplate>;
export const isIntegerFieldCollectionInputInstance = buildTypeGuard(zIntegerFieldCollectionInputInstance);
export const isIntegerFieldCollectionInputTemplate = buildTypeGuard(zIntegerFieldCollectionInputTemplate);
export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance =>
zIntegerFieldInputInstance.safeParse(val).success;
export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate =>
zIntegerFieldInputTemplate.safeParse(val).success;
// #endregion

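The collection templates above attach a `.refine()` check so that `maxItems` can never be smaller than `minItems`. A standalone sketch of the same pattern, reduced to just the two bounds so it can run in isolation (the schema name is illustrative):

```ts
import { z } from 'zod';

// Illustrative reduction of the minItems/maxItems refinement used by the collection templates.
const zItemBounds = z
  .object({
    minItems: z.number().int().gte(0).optional(),
    maxItems: z.number().int().gte(0).optional(),
  })
  .refine(
    (val) => {
      if (val.maxItems !== undefined && val.minItems !== undefined) {
        return val.maxItems >= val.minItems;
      }
      return true;
    },
    { message: 'maxItems must be greater than or equal to minItems' }
  );

zItemBounds.safeParse({ minItems: 1, maxItems: 4 }).success; // true
zItemBounds.safeParse({ minItems: 5, maxItems: 2 }).success; // false
```
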
// #region FloatField
@@ -344,48 +282,10 @@ const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
export const isFloatFieldInputInstance = buildTypeGuard(zFloatFieldInputInstance);
export const isFloatFieldInputTemplate = buildTypeGuard(zFloatFieldInputTemplate);
// #endregion

// #region FloatField Collection

export const zFloatFieldCollectionValue = z.array(zFloatFieldValue).optional();
const zFloatFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zFloatFieldCollectionValue,
generator: zFloatRangeStartStepCountGenerator.optional(),
lockLinearView: z.boolean().default(false),
});
const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zFloatCollectionFieldType,
originalType: zFieldType.optional(),
default: zFloatFieldCollectionValue,
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
multipleOf: z.number().int().optional(),
maximum: z.number().optional(),
exclusiveMaximum: z.number().optional(),
minimum: z.number().optional(),
exclusiveMinimum: z.number().optional(),
})
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);

const zFloatFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFloatCollectionFieldType,
});
export type FloatFieldCollectionInputInstance = z.infer<typeof zFloatFieldCollectionInputInstance>;
export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollectionInputTemplate>;
export const isFloatFieldCollectionInputInstance = buildTypeGuard(zFloatFieldCollectionInputInstance);
export const isFloatFieldCollectionInputTemplate = buildTypeGuard(zFloatFieldCollectionInputTemplate);
export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance =>
zFloatFieldInputInstance.safeParse(val).success;
export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate =>
zFloatFieldInputTemplate.safeParse(val).success;
// #endregion

// #region StringField
@@ -415,55 +315,13 @@ const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringFieldType,
});

// #region StringField Collection
export const zStringFieldCollectionValue = z.array(zStringFieldValue).optional();
const zStringFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zStringFieldCollectionValue,
});
const zStringFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zStringCollectionFieldType,
originalType: zFieldType.optional(),
default: zStringFieldCollectionValue,
maxLength: z.number().int().gte(0).optional(),
minLength: z.number().int().gte(0).optional(),
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
})
.refine(
(val) => {
if (val.maxLength !== undefined && val.minLength !== undefined) {
return val.maxLength >= val.minLength;
}
return true;
},
{ message: 'maxLength must be greater than or equal to minLength' }
)
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);

const zStringFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringCollectionFieldType,
});
export type StringFieldCollectionValue = z.infer<typeof zStringFieldCollectionValue>;
export type StringFieldCollectionInputInstance = z.infer<typeof zStringFieldCollectionInputInstance>;
export type StringFieldCollectionInputTemplate = z.infer<typeof zStringFieldCollectionInputTemplate>;
export const isStringFieldCollectionInputInstance = buildTypeGuard(zStringFieldCollectionInputInstance);
export const isStringFieldCollectionInputTemplate = buildTypeGuard(zStringFieldCollectionInputTemplate);
// #endregion

export type StringFieldValue = z.infer<typeof zStringFieldValue>;
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
export const isStringFieldInputInstance = buildTypeGuard(zStringFieldInputInstance);
export const isStringFieldInputTemplate = buildTypeGuard(zStringFieldInputTemplate);
export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance =>
zStringFieldInputInstance.safeParse(val).success;
export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate =>
zStringFieldInputTemplate.safeParse(val).success;
// #endregion

// #region BooleanField
@@ -483,8 +341,10 @@ const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
export type BooleanFieldInputTemplate = z.infer<typeof zBooleanFieldInputTemplate>;
export const isBooleanFieldInputInstance = buildTypeGuard(zBooleanFieldInputInstance);
export const isBooleanFieldInputTemplate = buildTypeGuard(zBooleanFieldInputTemplate);
export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance =>
zBooleanFieldInputInstance.safeParse(val).success;
export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate =>
zBooleanFieldInputTemplate.safeParse(val).success;
// #endregion

// #region EnumField
@@ -506,8 +366,10 @@ const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
export type EnumFieldInputTemplate = z.infer<typeof zEnumFieldInputTemplate>;
export const isEnumFieldInputInstance = buildTypeGuard(zEnumFieldInputInstance);
export const isEnumFieldInputTemplate = buildTypeGuard(zEnumFieldInputTemplate);
export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance =>
zEnumFieldInputInstance.safeParse(val).success;
export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate =>
zEnumFieldInputTemplate.safeParse(val).success;
// #endregion

// #region ImageField
@@ -526,8 +388,10 @@ const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
export const isImageFieldInputInstance = buildTypeGuard(zImageFieldInputInstance);
export const isImageFieldInputTemplate = buildTypeGuard(zImageFieldInputTemplate);
export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance =>
zImageFieldInputInstance.safeParse(val).success;
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
zImageFieldInputTemplate.safeParse(val).success;
// #endregion

// #region ImageField Collection
@@ -550,7 +414,7 @@ const zImageFieldCollectionInputTemplate = zFieldInputTemplateBase
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
{ message: 'maxLength must be greater than or equal to minLength' }
);

const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -559,8 +423,10 @@ const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
export type ImageFieldCollectionValue = z.infer<typeof zImageFieldCollectionValue>;
export type ImageFieldCollectionInputInstance = z.infer<typeof zImageFieldCollectionInputInstance>;
export type ImageFieldCollectionInputTemplate = z.infer<typeof zImageFieldCollectionInputTemplate>;
export const isImageFieldCollectionInputInstance = buildTypeGuard(zImageFieldCollectionInputInstance);
export const isImageFieldCollectionInputTemplate = buildTypeGuard(zImageFieldCollectionInputTemplate);
export const isImageFieldCollectionInputInstance = (val: unknown): val is ImageFieldCollectionInputInstance =>
zImageFieldCollectionInputInstance.safeParse(val).success;
export const isImageFieldCollectionInputTemplate = (val: unknown): val is ImageFieldCollectionInputTemplate =>
zImageFieldCollectionInputTemplate.safeParse(val).success;
// #endregion

// #region BoardField
@@ -580,8 +446,10 @@ const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
export type BoardFieldInputTemplate = z.infer<typeof zBoardFieldInputTemplate>;
export const isBoardFieldInputInstance = buildTypeGuard(zBoardFieldInputInstance);
export const isBoardFieldInputTemplate = buildTypeGuard(zBoardFieldInputTemplate);
export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance =>
zBoardFieldInputInstance.safeParse(val).success;
export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate =>
zBoardFieldInputTemplate.safeParse(val).success;
// #endregion

// #region ColorField
@@ -601,8 +469,10 @@ const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
export type ColorFieldInputTemplate = z.infer<typeof zColorFieldInputTemplate>;
export const isColorFieldInputInstance = buildTypeGuard(zColorFieldInputInstance);
export const isColorFieldInputTemplate = buildTypeGuard(zColorFieldInputTemplate);
export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance =>
zColorFieldInputInstance.safeParse(val).success;
export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate =>
zColorFieldInputTemplate.safeParse(val).success;
// #endregion

// #region MainModelField
@@ -622,8 +492,10 @@ const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
export const isMainModelFieldInputInstance = buildTypeGuard(zMainModelFieldInputInstance);
export const isMainModelFieldInputTemplate = buildTypeGuard(zMainModelFieldInputTemplate);
export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance =>
zMainModelFieldInputInstance.safeParse(val).success;
export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate =>
zMainModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region ModelIdentifierField
@@ -642,8 +514,10 @@ const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
export type ModelIdentifierFieldInputTemplate = z.infer<typeof zModelIdentifierFieldInputTemplate>;
export const isModelIdentifierFieldInputInstance = buildTypeGuard(zModelIdentifierFieldInputInstance);
export const isModelIdentifierFieldInputTemplate = buildTypeGuard(zModelIdentifierFieldInputTemplate);
export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance =>
zModelIdentifierFieldInputInstance.safeParse(val).success;
export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate =>
zModelIdentifierFieldInputTemplate.safeParse(val).success;
// #endregion

// #region SDXLMainModelField
@@ -662,8 +536,10 @@ const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
});
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
export const isSDXLMainModelFieldInputInstance = buildTypeGuard(zSDXLMainModelFieldInputInstance);
export const isSDXLMainModelFieldInputTemplate = buildTypeGuard(zSDXLMainModelFieldInputTemplate);
export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance =>
zSDXLMainModelFieldInputInstance.safeParse(val).success;
export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate =>
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region SD3MainModelField
@@ -682,8 +558,10 @@ const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = buildTypeGuard(zSD3MainModelFieldInputInstance);
export const isSD3MainModelFieldInputTemplate = buildTypeGuard(zSD3MainModelFieldInputTemplate);
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;

// #endregion

@@ -703,8 +581,10 @@ const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
});
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
export const isFluxMainModelFieldInputInstance = buildTypeGuard(zFluxMainModelFieldInputInstance);
export const isFluxMainModelFieldInputTemplate = buildTypeGuard(zFluxMainModelFieldInputTemplate);
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
zFluxMainModelFieldInputInstance.safeParse(val).success;
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
zFluxMainModelFieldInputTemplate.safeParse(val).success;

// #endregion

@@ -726,8 +606,10 @@ const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
export const isSDXLRefinerModelFieldInputInstance = buildTypeGuard(zSDXLRefinerModelFieldInputInstance);
export const isSDXLRefinerModelFieldInputTemplate = buildTypeGuard(zSDXLRefinerModelFieldInputTemplate);
export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance =>
zSDXLRefinerModelFieldInputInstance.safeParse(val).success;
export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate =>
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region VAEModelField
@@ -747,8 +629,10 @@ const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
export const isVAEModelFieldInputInstance = buildTypeGuard(zVAEModelFieldInputInstance);
export const isVAEModelFieldInputTemplate = buildTypeGuard(zVAEModelFieldInputTemplate);
export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance =>
zVAEModelFieldInputInstance.safeParse(val).success;
export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate =>
zVAEModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region LoRAModelField
@@ -768,8 +652,10 @@ const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
export const isLoRAModelFieldInputInstance = buildTypeGuard(zLoRAModelFieldInputInstance);
export const isLoRAModelFieldInputTemplate = buildTypeGuard(zLoRAModelFieldInputTemplate);
export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance =>
zLoRAModelFieldInputInstance.safeParse(val).success;
export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate =>
zLoRAModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region ControlNetModelField
@@ -789,8 +675,10 @@ const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
export const isControlNetModelFieldInputInstance = buildTypeGuard(zControlNetModelFieldInputInstance);
export const isControlNetModelFieldInputTemplate = buildTypeGuard(zControlNetModelFieldInputTemplate);
export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance =>
zControlNetModelFieldInputInstance.safeParse(val).success;
export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate =>
zControlNetModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region IPAdapterModelField
@@ -810,8 +698,10 @@ const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
export const isIPAdapterModelFieldInputInstance = buildTypeGuard(zIPAdapterModelFieldInputInstance);
export const isIPAdapterModelFieldInputTemplate = buildTypeGuard(zIPAdapterModelFieldInputTemplate);
export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance =>
zIPAdapterModelFieldInputInstance.safeParse(val).success;
export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate =>
zIPAdapterModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region T2IAdapterField
@@ -831,8 +721,10 @@ const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
export const isT2IAdapterModelFieldInputInstance = buildTypeGuard(zT2IAdapterModelFieldInputInstance);
export const isT2IAdapterModelFieldInputTemplate = buildTypeGuard(zT2IAdapterModelFieldInputTemplate);
export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance =>
zT2IAdapterModelFieldInputInstance.safeParse(val).success;
export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate =>
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region SpandrelModelToModelField
@@ -852,12 +744,14 @@ const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.e
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
export const isSpandrelImageToImageModelFieldInputInstance = buildTypeGuard(
zSpandrelImageToImageModelFieldInputInstance
);
export const isSpandrelImageToImageModelFieldInputTemplate = buildTypeGuard(
zSpandrelImageToImageModelFieldInputTemplate
);
export const isSpandrelImageToImageModelFieldInputInstance = (
val: unknown
): val is SpandrelImageToImageModelFieldInputInstance =>
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
export const isSpandrelImageToImageModelFieldInputTemplate = (
val: unknown
): val is SpandrelImageToImageModelFieldInputTemplate =>
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
// #endregion

// #region T5EncoderModelField
@@ -876,8 +770,10 @@ export type T5EncoderModelFieldValue = z.infer<typeof zT5EncoderModelFieldValue>

export type T5EncoderModelFieldInputInstance = z.infer<typeof zT5EncoderModelFieldInputInstance>;
export type T5EncoderModelFieldInputTemplate = z.infer<typeof zT5EncoderModelFieldInputTemplate>;
export const isT5EncoderModelFieldInputInstance = buildTypeGuard(zT5EncoderModelFieldInputInstance);
export const isT5EncoderModelFieldInputTemplate = buildTypeGuard(zT5EncoderModelFieldInputTemplate);
export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5EncoderModelFieldInputInstance =>
zT5EncoderModelFieldInputInstance.safeParse(val).success;
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
zT5EncoderModelFieldInputTemplate.safeParse(val).success;

// #endregion

@@ -897,8 +793,10 @@ export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;

export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
export const isFluxVAEModelFieldInputInstance = buildTypeGuard(zFluxVAEModelFieldInputInstance);
export const isFluxVAEModelFieldInputTemplate = buildTypeGuard(zFluxVAEModelFieldInputTemplate);
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
zFluxVAEModelFieldInputInstance.safeParse(val).success;
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
zFluxVAEModelFieldInputTemplate.safeParse(val).success;

// #endregion

@@ -918,8 +816,10 @@ export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>

export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
export const isCLIPEmbedModelFieldInputInstance = buildTypeGuard(zCLIPEmbedModelFieldInputInstance);
export const isCLIPEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPEmbedModelFieldInputTemplate);
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;

// #endregion

@@ -939,8 +839,10 @@ export type CLIPLEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValu

export type CLIPLEmbedModelFieldInputInstance = z.infer<typeof zCLIPLEmbedModelFieldInputInstance>;
export type CLIPLEmbedModelFieldInputTemplate = z.infer<typeof zCLIPLEmbedModelFieldInputTemplate>;
export const isCLIPLEmbedModelFieldInputInstance = buildTypeGuard(zCLIPLEmbedModelFieldInputInstance);
export const isCLIPLEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPLEmbedModelFieldInputTemplate);
export const isCLIPLEmbedModelFieldInputInstance = (val: unknown): val is CLIPLEmbedModelFieldInputInstance =>
zCLIPLEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPLEmbedModelFieldInputTemplate = (val: unknown): val is CLIPLEmbedModelFieldInputTemplate =>
zCLIPLEmbedModelFieldInputTemplate.safeParse(val).success;

// #endregion

@@ -960,8 +862,10 @@ export type CLIPGEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValu

export type CLIPGEmbedModelFieldInputInstance = z.infer<typeof zCLIPGEmbedModelFieldInputInstance>;
export type CLIPGEmbedModelFieldInputTemplate = z.infer<typeof zCLIPGEmbedModelFieldInputTemplate>;
export const isCLIPGEmbedModelFieldInputInstance = buildTypeGuard(zCLIPGEmbedModelFieldInputInstance);
export const isCLIPGEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPGEmbedModelFieldInputTemplate);
export const isCLIPGEmbedModelFieldInputInstance = (val: unknown): val is CLIPGEmbedModelFieldInputInstance =>
zCLIPGEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGEmbedModelFieldInputTemplate =>
zCLIPGEmbedModelFieldInputTemplate.safeParse(val).success;

// #endregion

@@ -981,8 +885,10 @@ export type ControlLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldVal

export type ControlLoRAModelFieldInputInstance = z.infer<typeof zControlLoRAModelFieldInputInstance>;
export type ControlLoRAModelFieldInputTemplate = z.infer<typeof zControlLoRAModelFieldInputTemplate>;
export const isControlLoRAModelFieldInputInstance = buildTypeGuard(zControlLoRAModelFieldInputInstance);
export const isControlLoRAModelFieldInputTemplate = buildTypeGuard(zControlLoRAModelFieldInputTemplate);
export const isControlLoRAModelFieldInputInstance = (val: unknown): val is ControlLoRAModelFieldInputInstance =>
zControlLoRAModelFieldInputInstance.safeParse(val).success;
export const isControlLoRAModelFieldInputTemplate = (val: unknown): val is ControlLoRAModelFieldInputTemplate =>
zControlLoRAModelFieldInputTemplate.safeParse(val).success;

// #endregion

@@ -1003,8 +909,10 @@ const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
export const isSchedulerFieldInputInstance = buildTypeGuard(zSchedulerFieldInputInstance);
export const isSchedulerFieldInputTemplate = buildTypeGuard(zSchedulerFieldInputTemplate);
export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
zSchedulerFieldInputInstance.safeParse(val).success;
export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate =>
zSchedulerFieldInputTemplate.safeParse(val).success;
// #endregion

// #region StatelessField
@@ -1055,11 +963,8 @@ export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTem
// #region StatefulFieldValue & FieldValue
export const zStatefulFieldValue = z.union([
zIntegerFieldValue,
zIntegerFieldCollectionValue,
zFloatFieldValue,
zFloatFieldCollectionValue,
zStringFieldValue,
zStringFieldCollectionValue,
zBooleanFieldValue,
zEnumFieldValue,
zImageFieldValue,
@@ -1095,11 +1000,8 @@ export type FieldValue = z.infer<typeof zFieldValue>;
// #region StatefulFieldInputInstance & FieldInputInstance
const zStatefulFieldInputInstance = z.union([
zIntegerFieldInputInstance,
zIntegerFieldCollectionInputInstance,
zFloatFieldInputInstance,
zFloatFieldCollectionInputInstance,
zStringFieldInputInstance,
zStringFieldCollectionInputInstance,
zBooleanFieldInputInstance,
zEnumFieldInputInstance,
zImageFieldInputInstance,
@@ -1126,17 +1028,15 @@ const zStatefulFieldInputInstance = z.union([

export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]);
export type FieldInputInstance = z.infer<typeof zFieldInputInstance>;
export const isFieldInputInstance = buildTypeGuard(zFieldInputInstance);
export const isFieldInputInstance = (val: unknown): val is FieldInputInstance =>
zFieldInputInstance.safeParse(val).success;
// #endregion

// #region StatefulFieldInputTemplate & FieldInputTemplate
const zStatefulFieldInputTemplate = z.union([
zIntegerFieldInputTemplate,
zIntegerFieldCollectionInputTemplate,
zFloatFieldInputTemplate,
zFloatFieldCollectionInputTemplate,
zStringFieldInputTemplate,
zStringFieldCollectionInputTemplate,
zBooleanFieldInputTemplate,
zEnumFieldInputTemplate,
zImageFieldInputTemplate,
@@ -1167,17 +1067,15 @@ const zStatefulFieldInputTemplate = z.union([

export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]);
export type FieldInputTemplate = z.infer<typeof zFieldInputTemplate>;
export const isFieldInputTemplate = buildTypeGuard(zFieldInputTemplate);
export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate =>
zFieldInputTemplate.safeParse(val).success;
// #endregion

// #region StatefulFieldOutputTemplate & FieldOutputTemplate
const zStatefulFieldOutputTemplate = z.union([
zIntegerFieldOutputTemplate,
zIntegerFieldCollectionOutputTemplate,
zFloatFieldOutputTemplate,
zFloatFieldCollectionOutputTemplate,
zStringFieldOutputTemplate,
zStringFieldCollectionOutputTemplate,
zBooleanFieldOutputTemplate,
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,

@@ -1,133 +0,0 @@
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
ImageFieldCollectionInputTemplate,
ImageFieldCollectionValue,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
StringFieldCollectionInputTemplate,
StringFieldCollectionValue,
} from 'features/nodes/types/field';
import {
floatRangeStartStepCountGenerator,
integerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { t } from 'i18next';

export const validateImageFieldCollectionValue = (
value: NonNullable<ImageFieldCollectionValue>,
template: ImageFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems } = template;
const count = value.length;

// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}

if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}

if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}

return reasons;
};

export const validateStringFieldCollectionValue = (
value: NonNullable<StringFieldCollectionValue>,
template: StringFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minLength, maxLength } = template;
const count = value.length;

// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}

if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}

if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}

for (const str of value) {
if (maxLength !== undefined && str.length > maxLength) {
reasons.push(t('parameters.invoke.collectionStringTooLong', { value, maxLength }));
}
if (minLength !== undefined && str.length < minLength) {
reasons.push(t('parameters.invoke.collectionStringTooShort', { value, minLength }));
}
}

return reasons;
};

export const resolveNumberFieldCollectionValue = (
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance
): number[] | undefined => {
if (field.generator?.type === 'float-range-generator-start-step-count') {
return floatRangeStartStepCountGenerator(field.generator);
} else if (field.generator?.type === 'integer-range-generator-start-step-count') {
return integerRangeStartStepCountGenerator(field.generator);
} else {
return field.value;
}
};

export const validateNumberFieldCollectionValue = (
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance,
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum, multipleOf } = template;
const value = resolveNumberFieldCollectionValue(field);

if (value === undefined) {
reasons.push(t('parameters.invoke.collectionEmpty'));
return reasons;
}

const count = value.length;

// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}

if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}

if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}

for (const num of value) {
if (maximum !== undefined && num > maximum) {
reasons.push(t('parameters.invoke.collectionNumberGTMax', { value, maximum }));
}
if (minimum !== undefined && num < minimum) {
reasons.push(t('parameters.invoke.collectionNumberLTMin', { value, minimum }));
}
if (exclusiveMaximum !== undefined && num >= exclusiveMaximum) {
reasons.push(t('parameters.invoke.collectionNumberGTExclusiveMax', { value, exclusiveMaximum }));
}
if (exclusiveMinimum !== undefined && num <= exclusiveMinimum) {
reasons.push(t('parameters.invoke.collectionNumberLTExclusiveMin', { value, exclusiveMinimum }));
}
if (multipleOf !== undefined && num % multipleOf !== 0) {
reasons.push(t('parameters.invoke.collectionNumberNotMultipleOf', { value, multipleOf }));
}
}

return reasons;
};
@@ -1,29 +0,0 @@
import { z } from 'zod';

export const zFloatRangeStartStepCountGenerator = z.object({
type: z.literal('float-range-generator-start-step-count').default('float-range-generator-start-step-count'),
start: z.number().default(0),
step: z.number().default(1),
count: z.number().int().default(10),
});
export type FloatRangeStartStepCountGenerator = z.infer<typeof zFloatRangeStartStepCountGenerator>;
export const floatRangeStartStepCountGenerator = (generator: FloatRangeStartStepCountGenerator): number[] => {
const { start, step, count } = generator;
return Array.from({ length: count }, (_, i) => start + i * step);
};
export const getDefaultFloatRangeStartStepCountGenerator = (): FloatRangeStartStepCountGenerator =>
zFloatRangeStartStepCountGenerator.parse({});

export const zIntegerRangeStartStepCountGenerator = z.object({
type: z.literal('integer-range-generator-start-step-count').default('integer-range-generator-start-step-count'),
start: z.number().int().default(0),
step: z.number().int().default(1),
count: z.number().int().default(10),
});
export type IntegerRangeStartStepCountGenerator = z.infer<typeof zIntegerRangeStartStepCountGenerator>;
export const integerRangeStartStepCountGenerator = (generator: IntegerRangeStartStepCountGenerator): number[] => {
const { start, step, count } = generator;
return Array.from({ length: count }, (_, i) => start + i * step);
};
export const getDefaultIntegerRangeStartStepCountGenerator = (): IntegerRangeStartStepCountGenerator =>
zIntegerRangeStartStepCountGenerator.parse({});
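As a quick illustration of the two range generators above (assuming they are imported from `features/nodes/types/generators`), the defaults produce ten consecutive integers, and explicit start/step/count values produce an evenly spaced float range:

```ts
import {
  floatRangeStartStepCountGenerator,
  getDefaultIntegerRangeStartStepCountGenerator,
  integerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';

// Defaults: start 0, step 1, count 10.
integerRangeStartStepCountGenerator(getDefaultIntegerRangeStartStepCountGenerator());
// -> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

floatRangeStartStepCountGenerator({
  type: 'float-range-generator-start-step-count',
  start: 0.5,
  step: 0.25,
  count: 4,
});
// -> [0.5, 0.75, 1, 1.25]
```
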
@@ -91,15 +91,3 @@ const zInvocationNodeEdgeExtra = z.object({
type InvocationNodeEdgeExtra = z.infer<typeof zInvocationNodeEdgeExtra>;
export type InvocationNodeEdge = Edge<InvocationNodeEdgeExtra>;
// #endregion

export const isBatchNode = (node: InvocationNode) => {
switch (node.data.type) {
case 'image_batch':
case 'string_batch':
case 'integer_batch':
case 'float_batch':
return true;
default:
return false;
}
};

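The graph builder below filters with lodash's `negate`, which simply inverts a predicate; a tiny self-contained demonstration of what that call does:

```ts
import { negate } from 'lodash-es';

// negate(fn) returns a function with the opposite boolean result, so
// `.filter(negate(isBatchNode))` keeps only nodes that are NOT batch nodes.
const isEven = (n: number) => n % 2 === 0;
[1, 2, 3, 4].filter(negate(isEven)); // [1, 3]
```
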
@@ -1,9 +1,7 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { negate, omit, reduce } from 'lodash-es';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';

@@ -16,7 +14,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const { nodes, edges } = nodesState;

// Exclude all batch nodes - we will handle these in the batch setup in a diff function
const filteredNodes = nodes.filter(isInvocationNode).filter(negate(isBatchNode));
const filteredNodes = nodes.filter(isInvocationNode).filter((node) => node.data.type !== 'image_batch');

// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
@@ -27,11 +25,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const transformedInputs = reduce(
inputs,
(inputsAccumulator, input, name) => {
if (isFloatFieldCollectionInputInstance(input) || isIntegerFieldCollectionInputInstance(input)) {
inputsAccumulator[name] = resolveNumberFieldCollectionValue(input);
} else {
inputsAccumulator[name] = input.value;
}
inputsAccumulator[name] = input.value;

return inputsAccumulator;
},
Some files were not shown because too many files have changed in this diff.