mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 16:07:54 -05:00
Compare commits
137 Commits
ryan/conse
...
ryan/flex-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
231099f913 | ||
|
|
66bc225bd3 | ||
|
|
7535d2e188 | ||
|
|
3dff87aeee | ||
|
|
b14bf1e0f4 | ||
|
|
4fdc6eec9d | ||
|
|
180a67d11b | ||
|
|
ec816d3c04 | ||
|
|
7dcc2dafbc | ||
|
|
81da5210f0 | ||
|
|
eb976a2ab0 | ||
|
|
724028d974 | ||
|
|
43c98fd99e | ||
|
|
526d64a5e2 | ||
|
|
58c6c6db53 | ||
|
|
8a41e09de3 | ||
|
|
c24eae1968 | ||
|
|
a6b207a0d9 | ||
|
|
eea5ecdd69 | ||
|
|
50de54dcfd | ||
|
|
04b893f982 | ||
|
|
4c655eeb48 | ||
|
|
298abab883 | ||
|
|
bd477ded2e | ||
|
|
0b64d21980 | ||
|
|
91d5f8537d | ||
|
|
e498e1f07c | ||
|
|
73a3f195dc | ||
|
|
8cc790a030 | ||
|
|
57265c8869 | ||
|
|
66d08eaa1c | ||
|
|
d69e90ca5e | ||
|
|
f345fde512 | ||
|
|
508c702289 | ||
|
|
8fbd2f9a97 | ||
|
|
bfb26af36a | ||
|
|
4400bc69f2 | ||
|
|
10f2c0dc9a | ||
|
|
5b0326fc49 | ||
|
|
2f9a0a250d | ||
|
|
5d03328dc6 | ||
|
|
1fb32aec28 | ||
|
|
2bbcd42036 | ||
|
|
2f40f7bafd | ||
|
|
65dd01bf3a | ||
|
|
81fc525f8a | ||
|
|
d2dd5ee408 | ||
|
|
b4b1daeb26 | ||
|
|
90c4c10e14 | ||
|
|
30e33d30d5 | ||
|
|
3df3be6c34 | ||
|
|
4e917bf2b2 | ||
|
|
26e6e28a13 | ||
|
|
f9cee42a06 | ||
|
|
1b8da023b8 | ||
|
|
05f1026812 | ||
|
|
ca1bd254ea | ||
|
|
29645326b9 | ||
|
|
c23a2abc82 | ||
|
|
803ec8e904 | ||
|
|
0abc0be931 | ||
|
|
edff16124f | ||
|
|
2e4110a29a | ||
|
|
7ee51f3e14 | ||
|
|
8ae75dbc35 | ||
|
|
9265716b07 | ||
|
|
27b9c07711 | ||
|
|
9dcbe3cc8f | ||
|
|
30165f66c3 | ||
|
|
deb70edc75 | ||
|
|
d82d990b23 | ||
|
|
2c64b60d32 | ||
|
|
4e8c6d931d | ||
|
|
9049e6e0f3 | ||
|
|
3cb5f8536b | ||
|
|
38e50cc7aa | ||
|
|
5bff6123b9 | ||
|
|
d63ff560d6 | ||
|
|
acceac8304 | ||
|
|
96671d12bd | ||
|
|
584601d03f | ||
|
|
b1c4ec0888 | ||
|
|
db5f016826 | ||
|
|
c1fd28472d | ||
|
|
0c5958675a | ||
|
|
912e07f2c8 | ||
|
|
f853b24868 | ||
|
|
4f900b22dc | ||
|
|
5823532941 | ||
|
|
bfe6d98cba | ||
|
|
c26b3cd54f | ||
|
|
c012d832d2 | ||
|
|
9d11d2aabd | ||
|
|
a5f1587ce7 | ||
|
|
0b26bb1ca3 | ||
|
|
0f1e632117 | ||
|
|
b212332b3e | ||
|
|
90a91ff438 | ||
|
|
b52b271dc4 | ||
|
|
e077fe8046 | ||
|
|
368957b208 | ||
|
|
27277e1fd6 | ||
|
|
236c0d89e7 | ||
|
|
b807170701 | ||
|
|
c5d2de3169 | ||
|
|
f7511bfd94 | ||
|
|
0abb5ea114 | ||
|
|
ce57c4ed2e | ||
|
|
0cf51cefe8 | ||
|
|
e5e848d239 | ||
|
|
da589b3f1f | ||
|
|
36a3869af0 | ||
|
|
c76d08d1fd | ||
|
|
04087c38ce | ||
|
|
b2bb359d47 | ||
|
|
b57aa06d9e | ||
|
|
f856246c36 | ||
|
|
195df2ebe6 | ||
|
|
7b5cef6bd7 | ||
|
|
69e7ffaaf5 | ||
|
|
993401ad6c | ||
|
|
8d570dcffc | ||
|
|
3f70e947fd | ||
|
|
157290bef4 | ||
|
|
b7389da89b | ||
|
|
254b89b1f5 | ||
|
|
2b122d7882 | ||
|
|
ded9213eb4 | ||
|
|
9d51eb49cd | ||
|
|
0a6e22bc9e | ||
|
|
b301785dc8 | ||
|
|
edcdff4f78 | ||
|
|
66e04ea7ab | ||
|
|
497bc916cc | ||
|
|
ebe1873712 | ||
|
|
59926c320c | ||
|
|
2d3e2f1907 |
@@ -28,11 +28,12 @@ It is possible to fine-tune the settings for best performance or if you still ge
|
||||
|
||||
## Details and fine-tuning
|
||||
|
||||
Low-VRAM mode involves 3 features, each of which can be configured or fine-tuned:
|
||||
Low-VRAM mode involves 4 features, each of which can be configured or fine-tuned:
|
||||
|
||||
- Partial model loading
|
||||
- Dynamic RAM and VRAM cache sizes
|
||||
- Working memory
|
||||
- Partial model loading (`enable_partial_loading`)
|
||||
- Dynamic RAM and VRAM cache sizes (`max_cache_ram_gb`, `max_cache_vram_gb`)
|
||||
- Working memory (`device_working_mem_gb`)
|
||||
- Keeping a RAM weight copy (`keep_ram_copy_of_weights`)
|
||||
|
||||
Read on to learn about these features and understand how to fine-tune them for your system and use-cases.
|
||||
|
||||
@@ -67,12 +68,20 @@ As of v5.6.0, the caches are dynamically sized. The `ram` and `vram` settings ar
|
||||
But, if your GPU has enough VRAM to hold models fully, you might get a perf boost by manually setting the cache sizes in `invokeai.yaml`:
|
||||
|
||||
```yaml
|
||||
# Set the RAM cache size to as large as possible, leaving a few GB free for the rest of your system and Invoke.
|
||||
# For example, if your system has 32GB RAM, 28GB is a good value.
|
||||
# The default max cache RAM size is logged on InvokeAI startup. It is determined based on your system RAM / VRAM.
|
||||
# You can override the default value by setting `max_cache_ram_gb`.
|
||||
# Increasing `max_cache_ram_gb` will increase the amount of RAM used to cache inactive models, resulting in faster model
|
||||
# reloads for the cached models.
|
||||
# As an example, if your system has 32GB of RAM and no other heavy processes, setting the `max_cache_ram_gb` to 28GB
|
||||
# might be a good value to achieve aggressive model caching.
|
||||
max_cache_ram_gb: 28
|
||||
# Set the VRAM cache size to be as large as possible while leaving enough room for the working memory of the tasks you will be doing.
|
||||
# For example, on a 24GB GPU that will be running unquantized FLUX without any auxiliary models,
|
||||
# 18GB is a good value.
|
||||
# The default max cache VRAM size is adjusted dynamically based on the amount of available VRAM (taking into
|
||||
# consideration the VRAM used by other processes).
|
||||
# You can override the default value by setting `max_cache_vram_gb`. Note that this value takes precedence over the
|
||||
# `device_working_mem_gb`.
|
||||
# It is recommended to set the VRAM cache size to be as large as possible while leaving enough room for the working
|
||||
# memory of the tasks you will be doing. For example, on a 24GB GPU that will be running unquantized FLUX without any
|
||||
# auxiliary models, 18GB might be a good value.
|
||||
max_cache_vram_gb: 18
|
||||
```
|
||||
|
||||
@@ -109,6 +118,15 @@ device_working_mem_gb: 4
|
||||
|
||||
Once decoding completes, the model manager "reclaims" the extra VRAM allocated as working memory for future model loading operations.
|
||||
|
||||
### Keeping a RAM weight copy
|
||||
|
||||
Invoke has the option of keeping a RAM copy of all model weights, even when they are loaded onto the GPU. This optimization is _on_ by default, and enables faster model switching and LoRA patching. Disabling this feature will reduce the average RAM load while running Invoke (peak RAM likely won't change), at the cost of slower model switching and LoRA patching. If you have limited RAM, you can disable this optimization:
|
||||
|
||||
```yaml
|
||||
# Set to false to reduce the average RAM usage at the cost of slower model switching and LoRA patching.
|
||||
keep_ram_copy_of_weights: false
|
||||
```
|
||||
|
||||
### Disabling Nvidia sysmem fallback (Windows only)
|
||||
|
||||
On Windows, Nvidia GPUs are able to use system RAM when their VRAM fills up via **sysmem fallback**. While it sounds like a good idea on the surface, in practice it causes massive slowdowns during generation.
|
||||
@@ -127,3 +145,19 @@ It is strongly suggested to disable this feature:
|
||||
If the sysmem fallback feature sounds familiar, that's because Invoke's partial model loading strategy is conceptually very similar - use VRAM when there's room, else fall back to RAM.
|
||||
|
||||
Unfortunately, the Nvidia implementation is not optimized for applications like Invoke and does more harm than good.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Windows page file
|
||||
|
||||
Invoke has high virtual memory (a.k.a. 'committed memory') requirements. This can cause issues on Windows if the page file size limits are hit. (See this issue for the technical details on why this happens: https://github.com/invoke-ai/InvokeAI/issues/7563).
|
||||
|
||||
If you run out of page file space, InvokeAI may crash. Often, these crashes will happen with one of the following errors:
|
||||
|
||||
- InvokeAI exits with Windows error code `3221225477`
|
||||
- InvokeAI crashes without an error, but `eventvwr.msc` reveals an error with code `0xc0000005` (the hex equivalent of `3221225477`)
|
||||
|
||||
If you are running out of page file space, try the following solutions:
|
||||
|
||||
- Make sure that you have sufficient disk space for the page file to grow. Watch your disk usage as Invoke runs. If it climbs near 100% leading up to the crash, then this is very likely the source of the issue. Clear out some disk space to resolve the issue.
|
||||
- Make sure that your page file is set to "System managed size" (this is the default) rather than a custom size. Under the "System managed size" policy, the page file will grow dynamically as needed.
|
||||
|
||||
@@ -25,6 +25,7 @@ async def parse_dynamicprompts(
|
||||
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
|
||||
max_prompts: int = Body(ge=1, le=10000, default=1000, description="The max number of prompts to generate"),
|
||||
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
||||
seed: int | None = Body(None, description="The seed to use for random generation. Only used if not combinatorial"),
|
||||
) -> DynamicPromptsResponse:
|
||||
"""Creates a batch process"""
|
||||
max_prompts = min(max_prompts, 10000)
|
||||
@@ -35,7 +36,7 @@ async def parse_dynamicprompts(
|
||||
generator = CombinatorialPromptGenerator()
|
||||
prompts = generator.generate(prompt, max_prompts=max_prompts)
|
||||
else:
|
||||
generator = RandomPromptGenerator()
|
||||
generator = RandomPromptGenerator(seed=seed)
|
||||
prompts = generator.generate(prompt, num_images=max_prompts)
|
||||
except ParseException as e:
|
||||
prompts = [prompt]
|
||||
|
||||
235
invokeai/app/invocations/batch.py
Normal file
235
invokeai/app/invocations/batch.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import (
|
||||
FloatOutput,
|
||||
ImageOutput,
|
||||
IntegerOutput,
|
||||
StringOutput,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
BATCH_GROUP_IDS = Literal[
|
||||
"None",
|
||||
"Group 1",
|
||||
"Group 2",
|
||||
"Group 3",
|
||||
"Group 4",
|
||||
"Group 5",
|
||||
]
|
||||
|
||||
|
||||
class NotExecutableNodeError(Exception):
|
||||
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
|
||||
super().__init__(message)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseBatchInvocation(BaseInvocation):
|
||||
batch_group_id: BATCH_GROUP_IDS = InputField(
|
||||
default="None",
|
||||
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
|
||||
input=Input.Direct,
|
||||
title="Batch Group",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(
|
||||
default=[], min_length=1, description="The images to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"string_batch",
|
||||
title="String Batch",
|
||||
tags=["primitives", "string", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class StringBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
|
||||
|
||||
strings: list[str] = InputField(
|
||||
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation_output("string_generator_output")
|
||||
class StringGeneratorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of strings"""
|
||||
|
||||
strings: list[str] = OutputField(description="The generated strings")
|
||||
|
||||
|
||||
class StringGeneratorField(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
@invocation(
|
||||
"string_generator",
|
||||
title="String Generator",
|
||||
tags=["primitives", "string", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class StringGenerator(BaseInvocation):
|
||||
"""Generated a range of strings for use in a batched generation"""
|
||||
|
||||
generator: StringGeneratorField = InputField(
|
||||
description="The string generator.",
|
||||
input=Input.Direct,
|
||||
title="Generator Type",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringGeneratorOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"integer_batch",
|
||||
title="Integer Batch",
|
||||
tags=["primitives", "integer", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class IntegerBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
|
||||
|
||||
integers: list[int] = InputField(
|
||||
default=[],
|
||||
min_length=1,
|
||||
description="The integers to batch over",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation_output("integer_generator_output")
|
||||
class IntegerGeneratorOutput(BaseInvocationOutput):
|
||||
integers: list[int] = OutputField(description="The generated integers")
|
||||
|
||||
|
||||
class IntegerGeneratorField(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
@invocation(
|
||||
"integer_generator",
|
||||
title="Integer Generator",
|
||||
tags=["primitives", "int", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class IntegerGenerator(BaseInvocation):
|
||||
"""Generated a range of integers for use in a batched generation"""
|
||||
|
||||
generator: IntegerGeneratorField = InputField(
|
||||
description="The integer generator.",
|
||||
input=Input.Direct,
|
||||
title="Generator Type",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerGeneratorOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"float_batch",
|
||||
title="Float Batch",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class FloatBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
|
||||
|
||||
floats: list[float] = InputField(
|
||||
default=[],
|
||||
min_length=1,
|
||||
description="The floats to batch over",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation_output("float_generator_output")
|
||||
class FloatGeneratorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of floats"""
|
||||
|
||||
floats: list[float] = OutputField(description="The generated floats")
|
||||
|
||||
|
||||
class FloatGeneratorField(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
@invocation(
|
||||
"float_generator",
|
||||
title="Float Generator",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class FloatGenerator(BaseInvocation):
|
||||
"""Generated a range of floats for use in a batched generation"""
|
||||
|
||||
generator: FloatGeneratorField = InputField(
|
||||
description="The float generator.",
|
||||
input=Input.Direct,
|
||||
title="Generator Type",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatGeneratorOutput:
|
||||
raise NotExecutableNodeError()
|
||||
@@ -40,6 +40,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -85,6 +86,7 @@ def get_scheduler(
|
||||
scheduler_info: ModelIdentifierField,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
unet_config: AnyModelConfig,
|
||||
) -> Scheduler:
|
||||
"""Load a scheduler and apply some scheduler-specific overrides."""
|
||||
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
|
||||
@@ -103,6 +105,9 @@ def get_scheduler(
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
|
||||
if hasattr(unet_config, "prediction_type"):
|
||||
scheduler_config["prediction_type"] = unet_config.prediction_type
|
||||
|
||||
# make dpmpp_sde reproducable(seed can be passed only in initializer)
|
||||
if scheduler_class is DPMSolverSDEScheduler:
|
||||
scheduler_config["noise_sampler_seed"] = seed
|
||||
@@ -829,6 +834,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
@@ -848,6 +856,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
unet_config=unet_config,
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
@@ -859,9 +868,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
### preview
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
@@ -1030,6 +1036,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
unet_config=unet_config,
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
|
||||
@@ -10,6 +10,10 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
@@ -74,8 +78,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
tokenizer2 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
|
||||
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
|
||||
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
@@ -2,7 +2,7 @@ from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
@@ -76,7 +76,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
|
||||
):
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, T5Tokenizer)
|
||||
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
@@ -161,12 +162,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
crop: bool = InputField(default=False, description="Crop to base image dimensions")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.images.get_pil(self.base_image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
base_image = context.images.get_pil(self.base_image.image_name, mode="RGBA")
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
mask = None
|
||||
if self.mask is not None:
|
||||
mask = context.images.get_pil(self.mask.image_name)
|
||||
mask = ImageOps.invert(mask.convert("L"))
|
||||
mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
mask = ImageOps.invert(mask)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
|
||||
min_x = min(0, self.x)
|
||||
@@ -176,7 +177,11 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
|
||||
# Create a temporary image to paste the image with transparency
|
||||
temp_image = Image.new("RGBA", new_image.size)
|
||||
temp_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
new_image = Image.alpha_composite(new_image, temp_image)
|
||||
|
||||
if self.crop:
|
||||
base_w, base_h = base_image.size
|
||||
@@ -301,14 +306,44 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
|
||||
# Split the image into RGBA channels
|
||||
r, g, b, a = image.split()
|
||||
|
||||
# Premultiply RGB channels by alpha
|
||||
premultiplied_image = ImageChops.multiply(image, a.convert("RGBA"))
|
||||
premultiplied_image.putalpha(a)
|
||||
|
||||
# Apply the blur
|
||||
blur = (
|
||||
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
|
||||
)
|
||||
blur_image = image.filter(blur)
|
||||
blurred_image = premultiplied_image.filter(blur)
|
||||
|
||||
image_dto = context.images.save(image=blur_image)
|
||||
# Split the blurred image into RGBA channels
|
||||
r, g, b, a_orig = blurred_image.split()
|
||||
|
||||
# Convert to float using NumPy. float 32/64 division are much faster than float 16
|
||||
r = numpy.array(r, dtype=numpy.float32)
|
||||
g = numpy.array(g, dtype=numpy.float32)
|
||||
b = numpy.array(b, dtype=numpy.float32)
|
||||
a = numpy.array(a_orig, dtype=numpy.float32) / 255.0 # Normalize alpha to [0, 1]
|
||||
|
||||
# Unpremultiply RGB channels by alpha
|
||||
r /= a + 1e-6 # Add a small epsilon to avoid division by zero
|
||||
g /= a + 1e-6
|
||||
b /= a + 1e-6
|
||||
|
||||
# Convert back to PIL images
|
||||
r = Image.fromarray(numpy.uint8(numpy.clip(r, 0, 255)))
|
||||
g = Image.fromarray(numpy.uint8(numpy.clip(g, 0, 255)))
|
||||
b = Image.fromarray(numpy.uint8(numpy.clip(b, 0, 255)))
|
||||
|
||||
# Merge back into a single image
|
||||
result_image = Image.merge("RGBA", (r, g, b, a_orig))
|
||||
|
||||
image_dto = context.images.save(image=result_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -1055,3 +1090,67 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_dto = context.images.save(image=generated_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_noise",
|
||||
title="Add Image Noise",
|
||||
tags=["image", "noise"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Add noise to an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to add noise to")
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description=FieldDescriptions.seed,
|
||||
)
|
||||
noise_type: Literal["gaussian", "salt_and_pepper"] = InputField(
|
||||
default="gaussian",
|
||||
description="The type of noise to add",
|
||||
)
|
||||
amount: float = InputField(default=0.1, ge=0, le=1, description="The amount of noise to add")
|
||||
noise_color: bool = InputField(default=True, description="Whether to add colored noise")
|
||||
size: int = InputField(default=1, ge=1, description="The size of the noise points")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
|
||||
# Save out the alpha channel
|
||||
alpha = image.getchannel("A")
|
||||
|
||||
# Set the seed for numpy random
|
||||
rs = numpy.random.RandomState(numpy.random.MT19937(numpy.random.SeedSequence(self.seed)))
|
||||
|
||||
if self.noise_type == "gaussian":
|
||||
if self.noise_color:
|
||||
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size, 3)) * 255
|
||||
else:
|
||||
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size)) * 255
|
||||
noise = numpy.stack([noise] * 3, axis=-1)
|
||||
elif self.noise_type == "salt_and_pepper":
|
||||
if self.noise_color:
|
||||
noise = rs.choice(
|
||||
[0, 255], (image.height // self.size, image.width // self.size, 3), p=[1 - self.amount, self.amount]
|
||||
)
|
||||
else:
|
||||
noise = rs.choice(
|
||||
[0, 255], (image.height // self.size, image.width // self.size), p=[1 - self.amount, self.amount]
|
||||
)
|
||||
noise = numpy.stack([noise] * 3, axis=-1)
|
||||
|
||||
noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
|
||||
(image.width, image.height), Image.Resampling.NEAREST
|
||||
)
|
||||
noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")
|
||||
|
||||
# Paste back the alpha channel
|
||||
noisy_image.putalpha(alpha)
|
||||
|
||||
image_dto = context.images.save(image=noisy_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -7,7 +7,6 @@ import torch
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -539,23 +538,3 @@ class BoundingBoxInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "internal"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageBatchInvocation(BaseInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
@@ -10,6 +10,10 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
|
||||
@@ -88,16 +92,8 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
|
||||
if self.clip_g_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
)
|
||||
tokenizer_t5 = (
|
||||
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
if self.t5_encoder_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
)
|
||||
t5_encoder = (
|
||||
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
if self.t5_encoder_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
)
|
||||
tokenizer_t5 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model or self.model)
|
||||
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model or self.model)
|
||||
|
||||
return Sd3ModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
|
||||
@@ -218,6 +218,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
unet_config=unet_config,
|
||||
)
|
||||
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
||||
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
|
||||
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
|
||||
keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.
|
||||
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
|
||||
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
|
||||
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
|
||||
@@ -162,6 +163,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
|
||||
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
|
||||
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
|
||||
keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
|
||||
# Deprecated CACHE configs
|
||||
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
|
||||
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
|
||||
|
||||
@@ -84,6 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
ram_cache = ModelCache(
|
||||
execution_device_working_mem_gb=app_config.device_working_mem_gb,
|
||||
enable_partial_loading=app_config.enable_partial_loading,
|
||||
keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights,
|
||||
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
|
||||
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
|
||||
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||
|
||||
@@ -108,8 +108,16 @@ class Batch(BaseModel):
|
||||
return v
|
||||
for batch_data_list in v:
|
||||
for datum in batch_data_list:
|
||||
if not datum.items:
|
||||
continue
|
||||
|
||||
# Special handling for numbers - they can be mixed
|
||||
# TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
|
||||
if all(isinstance(item, (int, float)) for item in datum.items):
|
||||
continue
|
||||
|
||||
# Get the type of the first item in the list
|
||||
first_item_type = type(datum.items[0]) if datum.items else None
|
||||
first_item_type = type(datum.items[0])
|
||||
for item in datum.items:
|
||||
if type(item) is not first_item_type:
|
||||
raise BatchItemsTypeError("All items in a batch must have the same type")
|
||||
|
||||
26
invokeai/app/util/t5_model_identifier.py
Normal file
26
invokeai/app/util/t5_model_identifier.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.backend.model_manager.config import BaseModelType, SubModelType
|
||||
|
||||
|
||||
def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
|
||||
"""A helper function to normalize a T5 encoder model identifier so that T5 models associated with FLUX
|
||||
or SD3 models can be used interchangeably.
|
||||
"""
|
||||
if model_identifier.base == BaseModelType.Any:
|
||||
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
elif model_identifier.base == BaseModelType.StableDiffusion3:
|
||||
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
else:
|
||||
raise ValueError(f"Unsupported model base: {model_identifier.base}")
|
||||
|
||||
|
||||
def preprocess_t5_tokenizer_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
|
||||
"""A helper function to normalize a T5 tokenizer model identifier so that T5 models associated with FLUX
|
||||
or SD3 models can be used interchangeably.
|
||||
"""
|
||||
if model_identifier.base == BaseModelType.Any:
|
||||
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
elif model_identifier.base == BaseModelType.StableDiffusion3:
|
||||
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
else:
|
||||
raise ValueError(f"Unsupported model base: {model_identifier.base}")
|
||||
@@ -1,13 +1,19 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from torch import Tensor, nn
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
class HFEncoder(nn.Module):
|
||||
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
is_clip: bool,
|
||||
max_length: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
self.is_clip = is_clip
|
||||
|
||||
@@ -9,12 +9,17 @@ class CachedModelOnlyFullLoad:
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
|
||||
def __init__(
|
||||
self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False
|
||||
):
|
||||
"""Initialize a CachedModelOnlyFullLoad.
|
||||
Args:
|
||||
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
|
||||
compute_device (torch.device): The compute device to move the model to.
|
||||
total_bytes (int): The total size (in bytes) of all the weights in the model.
|
||||
keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy
|
||||
increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is
|
||||
sufficient RAM).
|
||||
"""
|
||||
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
|
||||
self._model = model
|
||||
@@ -23,7 +28,7 @@ class CachedModelOnlyFullLoad:
|
||||
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] | None = None
|
||||
if isinstance(model, torch.nn.Module):
|
||||
if isinstance(model, torch.nn.Module) and keep_ram_copy:
|
||||
self._cpu_state_dict = model.state_dict()
|
||||
|
||||
self._total_bytes = total_bytes
|
||||
|
||||
@@ -14,33 +14,38 @@ class CachedModelWithPartialLoad:
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
|
||||
def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False):
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
|
||||
model_state_dict = model.state_dict()
|
||||
# A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA
|
||||
# patching. Set to `None` if keep_ram_copy is False.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None
|
||||
|
||||
# A dictionary of the size of each tensor in the state dict.
|
||||
# HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
|
||||
# consistency in case the application code has modified the model's size (e.g. by casting to a different
|
||||
# precision). Of course, this means that we are making model cache load/unload decisions based on model size
|
||||
# data that may not be fully accurate.
|
||||
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
|
||||
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()}
|
||||
|
||||
self._total_bytes = sum(self._state_dict_bytes.values())
|
||||
self._cur_vram_bytes: int | None = None
|
||||
|
||||
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
|
||||
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
|
||||
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
|
||||
model_state_dict
|
||||
)
|
||||
self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict)
|
||||
|
||||
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
|
||||
"""Find all modules that support autocasting."""
|
||||
return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore
|
||||
|
||||
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
|
||||
def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]:
|
||||
keys_in_modules_that_do_not_support_autocast: set[str] = set()
|
||||
for key in self._cpu_state_dict.keys():
|
||||
for key in state_dict.keys():
|
||||
for module_name in self._modules_that_support_autocast.keys():
|
||||
if key.startswith(module_name):
|
||||
break
|
||||
@@ -48,6 +53,47 @@ class CachedModelWithPartialLoad:
|
||||
keys_in_modules_that_do_not_support_autocast.add(key)
|
||||
return keys_in_modules_that_do_not_support_autocast
|
||||
|
||||
def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]:
|
||||
"""A helper function that groups state dict keys by module prefix.
|
||||
|
||||
Example:
|
||||
```
|
||||
state_dict = {
|
||||
"weight": ...,
|
||||
"module.submodule.weight": ...,
|
||||
"module.submodule.bias": ...,
|
||||
"module.other_submodule.weight": ...,
|
||||
"module.other_submodule.bias": ...,
|
||||
}
|
||||
|
||||
output = group_state_dict_keys_by_module_prefix(state_dict)
|
||||
|
||||
# The output will be:
|
||||
output = {
|
||||
"": [
|
||||
"weight",
|
||||
],
|
||||
"module.submodule": [
|
||||
"module.submodule.weight",
|
||||
"module.submodule.bias",
|
||||
],
|
||||
"module.other_submodule": [
|
||||
"module.other_submodule.weight",
|
||||
"module.other_submodule.bias",
|
||||
],
|
||||
}
|
||||
```
|
||||
"""
|
||||
state_dict_keys_by_module_prefix: dict[str, list[str]] = {}
|
||||
for key in state_dict.keys():
|
||||
split = key.rsplit(".", 1)
|
||||
# `split` will have length 1 if the root module has parameters.
|
||||
module_name = split[0] if len(split) > 1 else ""
|
||||
if module_name not in state_dict_keys_by_module_prefix:
|
||||
state_dict_keys_by_module_prefix[module_name] = []
|
||||
state_dict_keys_by_module_prefix[module_name].append(key)
|
||||
return state_dict_keys_by_module_prefix
|
||||
|
||||
def _move_non_persistent_buffers_to_device(self, device: torch.device):
|
||||
"""Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
|
||||
so we need to move them manually.
|
||||
@@ -98,6 +144,82 @@ class CachedModelWithPartialLoad:
|
||||
"""Unload all weights from VRAM."""
|
||||
return self.partial_unload_from_vram(self.total_bytes())
|
||||
|
||||
def _load_state_dict_with_device_conversion(
|
||||
self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
|
||||
):
|
||||
if self._cpu_state_dict is not None:
|
||||
# Run the fast version.
|
||||
self._load_state_dict_with_fast_device_conversion(
|
||||
state_dict=state_dict,
|
||||
keys_to_convert=keys_to_convert,
|
||||
target_device=target_device,
|
||||
cpu_state_dict=self._cpu_state_dict,
|
||||
)
|
||||
else:
|
||||
# Run the low-virtual-memory version.
|
||||
self._load_state_dict_with_jit_device_conversion(
|
||||
state_dict=state_dict,
|
||||
keys_to_convert=keys_to_convert,
|
||||
target_device=target_device,
|
||||
)
|
||||
|
||||
def _load_state_dict_with_jit_device_conversion(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
keys_to_convert: set[str],
|
||||
target_device: torch.device,
|
||||
):
|
||||
"""A custom state dict loading implementation with good peak memory properties.
|
||||
|
||||
This implementation has the important property that it copies parameters to the target device one module at a time
|
||||
rather than applying all of the device conversions and then calling load_state_dict(). This is done to minimize the
|
||||
peak virtual memory usage. Specifically, we want to avoid a case where we hold references to all of the CPU weights
|
||||
and CUDA weights simultaneously, because Windows will reserve virtual memory for both.
|
||||
"""
|
||||
for module_name, module in self._model.named_modules():
|
||||
module_keys = self._state_dict_keys_by_module_prefix.get(module_name, [])
|
||||
# Calculate the length of the module name prefix.
|
||||
prefix_len = len(module_name)
|
||||
if prefix_len > 0:
|
||||
prefix_len += 1
|
||||
|
||||
module_state_dict = {}
|
||||
for key in module_keys:
|
||||
if key in keys_to_convert:
|
||||
# It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same
|
||||
# parameter.
|
||||
state_dict[key] = state_dict[key].to(target_device)
|
||||
# Note that we keep parameters that have not been moved to a new device in case the module implements
|
||||
# weird custom state dict loading logic that requires all parameters to be present.
|
||||
module_state_dict[key[prefix_len:]] = state_dict[key]
|
||||
|
||||
if len(module_state_dict) > 0:
|
||||
# We set strict=False, because if `module` has both parameters and child modules, then we are loading a
|
||||
# state dict that only contains the parameters of `module` (not its children).
|
||||
# We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf
|
||||
# modules will recurse through all of the children, so is a bit wasteful.
|
||||
incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True)
|
||||
# Missing keys are ok, unexpected keys are not.
|
||||
assert len(incompatible_keys.unexpected_keys) == 0
|
||||
|
||||
def _load_state_dict_with_fast_device_conversion(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
keys_to_convert: set[str],
|
||||
target_device: torch.device,
|
||||
cpu_state_dict: dict[str, torch.Tensor],
|
||||
):
|
||||
"""Convert parameters to the target device and load them into the model. Leverages the `cpu_state_dict` to speed
|
||||
up transfers of weights to the CPU.
|
||||
"""
|
||||
for key in keys_to_convert:
|
||||
if target_device.type == "cpu":
|
||||
state_dict[key] = cpu_state_dict[key]
|
||||
else:
|
||||
state_dict[key] = state_dict[key].to(target_device)
|
||||
|
||||
self._model.load_state_dict(state_dict, assign=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
|
||||
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
|
||||
@@ -112,26 +234,33 @@ class CachedModelWithPartialLoad:
|
||||
|
||||
cur_state_dict = self._model.state_dict()
|
||||
|
||||
# Identify the keys that will be loaded into VRAM.
|
||||
keys_to_load: set[str] = set()
|
||||
|
||||
# First, process the keys that *must* be loaded into VRAM.
|
||||
for key in self._keys_in_modules_that_do_not_support_autocast:
|
||||
param = cur_state_dict[key]
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
keys_to_load.add(key)
|
||||
param_size = self._state_dict_bytes[key]
|
||||
cur_state_dict[key] = param.to(self._compute_device, copy=True)
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
if vram_bytes_loaded > vram_bytes_to_load:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
logger.warning(
|
||||
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
|
||||
f"Loading {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
|
||||
"requested. This is the minimum set of weights in VRAM required to run the model."
|
||||
)
|
||||
|
||||
# Next, process the keys that can optionally be loaded into VRAM.
|
||||
fully_loaded = True
|
||||
for key, param in cur_state_dict.items():
|
||||
# Skip the keys that have already been processed above.
|
||||
if key in keys_to_load:
|
||||
continue
|
||||
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
@@ -142,14 +271,14 @@ class CachedModelWithPartialLoad:
|
||||
fully_loaded = False
|
||||
continue
|
||||
|
||||
cur_state_dict[key] = param.to(self._compute_device, copy=True)
|
||||
keys_to_load.add(key)
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
if vram_bytes_loaded > 0:
|
||||
if len(keys_to_load) > 0:
|
||||
# We load the entire state dict, not just the parameters that changed, in case there are modules that
|
||||
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
|
||||
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
|
||||
self._model.load_state_dict(cur_state_dict, assign=True)
|
||||
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_load, self._compute_device)
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes += vram_bytes_loaded
|
||||
@@ -180,6 +309,10 @@ class CachedModelWithPartialLoad:
|
||||
|
||||
offload_device = "cpu"
|
||||
cur_state_dict = self._model.state_dict()
|
||||
|
||||
# Identify the keys that will be offloaded to CPU.
|
||||
keys_to_offload: set[str] = set()
|
||||
|
||||
for key, param in cur_state_dict.items():
|
||||
if vram_bytes_freed >= vram_bytes_to_free:
|
||||
break
|
||||
@@ -191,11 +324,11 @@ class CachedModelWithPartialLoad:
|
||||
required_weights_in_vram += self._state_dict_bytes[key]
|
||||
continue
|
||||
|
||||
cur_state_dict[key] = self._cpu_state_dict[key]
|
||||
keys_to_offload.add(key)
|
||||
vram_bytes_freed += self._state_dict_bytes[key]
|
||||
|
||||
if vram_bytes_freed > 0:
|
||||
self._model.load_state_dict(cur_state_dict, assign=True)
|
||||
if len(keys_to_offload) > 0:
|
||||
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_offload, torch.device("cpu"))
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes -= vram_bytes_freed
|
||||
|
||||
@@ -78,6 +78,7 @@ class ModelCache:
|
||||
self,
|
||||
execution_device_working_mem_gb: float,
|
||||
enable_partial_loading: bool,
|
||||
keep_ram_copy_of_weights: bool,
|
||||
max_ram_cache_size_gb: float | None = None,
|
||||
max_vram_cache_size_gb: float | None = None,
|
||||
execution_device: torch.device | str = "cuda",
|
||||
@@ -105,6 +106,7 @@ class ModelCache:
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
self._enable_partial_loading = enable_partial_loading
|
||||
self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
|
||||
self._execution_device_working_mem_gb = execution_device_working_mem_gb
|
||||
self._execution_device: torch.device = torch.device(execution_device)
|
||||
self._storage_device: torch.device = torch.device(storage_device)
|
||||
@@ -121,6 +123,8 @@ class ModelCache:
|
||||
self._cached_models: Dict[str, CacheRecord] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
self._ram_cache_size_bytes = self._calc_ram_available_to_model_cache()
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
@@ -154,9 +158,13 @@ class ModelCache:
|
||||
|
||||
# Wrap model.
|
||||
if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
|
||||
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
|
||||
wrapped_model = CachedModelWithPartialLoad(
|
||||
model, self._execution_device, keep_ram_copy=self._keep_ram_copy_of_weights
|
||||
)
|
||||
else:
|
||||
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
|
||||
wrapped_model = CachedModelOnlyFullLoad(
|
||||
model, self._execution_device, size, keep_ram_copy=self._keep_ram_copy_of_weights
|
||||
)
|
||||
|
||||
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
|
||||
self._cached_models[key] = cache_record
|
||||
@@ -382,41 +390,89 @@ class ModelCache:
|
||||
# Alternative definition of VRAM in use:
|
||||
# return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
|
||||
|
||||
def _get_ram_available(self) -> int:
|
||||
"""Get the amount of RAM available for the cache to use, while keeping memory pressure under control."""
|
||||
def _calc_ram_available_to_model_cache(self) -> int:
|
||||
"""Calculate the amount of RAM available for the cache to use."""
|
||||
# If self._max_ram_cache_size_gb is set, then it overrides the default logic.
|
||||
if self._max_ram_cache_size_gb is not None:
|
||||
ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
|
||||
return ram_total_available_to_cache - self._get_ram_in_use()
|
||||
self._logger.info(f"Using user-defined RAM cache size: {self._max_ram_cache_size_gb} GB.")
|
||||
return int(self._max_ram_cache_size_gb * GB)
|
||||
|
||||
virtual_memory = psutil.virtual_memory()
|
||||
ram_total = virtual_memory.total
|
||||
ram_available = virtual_memory.available
|
||||
ram_used = ram_total - ram_available
|
||||
# Heuristics for dynamically calculating the RAM cache size, **in order of increasing priority**:
|
||||
# 1. As an initial default, use 50% of the total RAM for InvokeAI.
|
||||
# - Assume a 2GB baseline for InvokeAI's non-model RAM usage, and use the rest of the RAM for the model cache.
|
||||
# 2. On a system with a lot of RAM (e.g. 64GB+), users probably don't want InvokeAI to eat up too much RAM.
|
||||
# There are diminishing returns to storing more and more models. So, we apply an upper bound.
|
||||
# - On systems without a CUDA device, the upper bound is 32GB.
|
||||
# - On systems with a CUDA device, the upper bound is 2x the amount of VRAM.
|
||||
# 3. On systems with a CUDA device, the minimum should be the VRAM size (less the working memory).
|
||||
# - Setting lower than this would mean that we sometimes kick models out of the cache when there is room for
|
||||
# all models in VRAM.
|
||||
# - Consider an extreme case of a system with 8GB RAM / 24GB VRAM. I haven't tested this, but I think
|
||||
# you'd still want the RAM cache size to be ~24GB (less the working memory). (Though you'd probably want to
|
||||
# set `keep_ram_copy_of_weights: false` in this case.)
|
||||
# 4. Absolute minimum of 4GB.
|
||||
|
||||
# The total size of all the models in the cache will often be larger than the amount of RAM reported by psutil
|
||||
# (due to lazy-loading and OS RAM caching behaviour). We could just rely on the psutil values, but it feels
|
||||
# like a bad idea to over-fill the model cache. So, for now, we'll try to keep the total size of models in the
|
||||
# cache under the total amount of system RAM.
|
||||
cache_ram_used = self._get_ram_in_use()
|
||||
ram_used = max(cache_ram_used, ram_used)
|
||||
# NOTE(ryand): We explored dynamically adjusting the RAM cache size based on memory pressure (using psutil), but
|
||||
# decided against it for now, for the following reasons:
|
||||
# - It was surprisingly difficult to get memory metrics with consistent definitions across OSes. (If you go
|
||||
# down this path again, don't underestimate the amount of complexity here and be sure to test rigorously on all
|
||||
# OSes.)
|
||||
# - Making the RAM cache size dynamic opens the door for performance regressions that are hard to diagnose and
|
||||
# hard for users to understand. It is better for users to see that their RAM is maxed out, and then override
|
||||
# the default value if desired.
|
||||
|
||||
# Aim to keep 10% of RAM free.
|
||||
ram_available_based_on_memory_usage = int(ram_total * 0.9) - ram_used
|
||||
# Lookup the total VRAM size for the CUDA execution device.
|
||||
total_cuda_vram_bytes: int | None = None
|
||||
if self._execution_device.type == "cuda":
|
||||
_, total_cuda_vram_bytes = torch.cuda.mem_get_info(self._execution_device)
|
||||
|
||||
# If we are running out of RAM, then there's an increased likelihood that we will run into this issue:
|
||||
# https://github.com/invoke-ai/InvokeAI/issues/7513
|
||||
# To keep things running smoothly, there's a minimum RAM cache size that we always allow (even if this means
|
||||
# using swap).
|
||||
min_ram_cache_size_bytes = 4 * GB
|
||||
ram_available_based_on_min_cache_size = min_ram_cache_size_bytes - cache_ram_used
|
||||
# Apply heuristic 1.
|
||||
# ------------------
|
||||
heuristics_applied = [1]
|
||||
total_system_ram_bytes = psutil.virtual_memory().total
|
||||
# Assumed baseline RAM used by InvokeAI for non-model stuff.
|
||||
baseline_ram_used_by_invokeai = 2 * GB
|
||||
ram_available_to_model_cache = int(total_system_ram_bytes * 0.5 - baseline_ram_used_by_invokeai)
|
||||
|
||||
return max(ram_available_based_on_memory_usage, ram_available_based_on_min_cache_size)
|
||||
# Apply heuristic 2.
|
||||
# ------------------
|
||||
max_ram_cache_size_bytes = 32 * GB
|
||||
if total_cuda_vram_bytes is not None:
|
||||
max_ram_cache_size_bytes = 2 * total_cuda_vram_bytes
|
||||
if ram_available_to_model_cache > max_ram_cache_size_bytes:
|
||||
heuristics_applied.append(2)
|
||||
ram_available_to_model_cache = max_ram_cache_size_bytes
|
||||
|
||||
# Apply heuristic 3.
|
||||
# ------------------
|
||||
if total_cuda_vram_bytes is not None:
|
||||
if self._max_vram_cache_size_gb is not None:
|
||||
min_ram_cache_size_bytes = int(self._max_vram_cache_size_gb * GB)
|
||||
else:
|
||||
min_ram_cache_size_bytes = total_cuda_vram_bytes - int(self._execution_device_working_mem_gb * GB)
|
||||
if ram_available_to_model_cache < min_ram_cache_size_bytes:
|
||||
heuristics_applied.append(3)
|
||||
ram_available_to_model_cache = min_ram_cache_size_bytes
|
||||
|
||||
# Apply heuristic 4.
|
||||
# ------------------
|
||||
if ram_available_to_model_cache < 4 * GB:
|
||||
heuristics_applied.append(4)
|
||||
ram_available_to_model_cache = 4 * GB
|
||||
|
||||
self._logger.info(
|
||||
f"Calculated model RAM cache size: {ram_available_to_model_cache / MB:.2f} MB. Heuristics applied: {heuristics_applied}."
|
||||
)
|
||||
return ram_available_to_model_cache
|
||||
|
||||
def _get_ram_in_use(self) -> int:
|
||||
"""Get the amount of RAM currently in use."""
|
||||
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
|
||||
|
||||
def _get_ram_available(self) -> int:
|
||||
"""Get the amount of RAM available for the cache to use."""
|
||||
return self._ram_cache_size_bytes - self._get_ram_in_use()
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
|
||||
@@ -80,19 +80,19 @@ class FluxVAELoader(ModelLoader):
|
||||
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoEncoder(ae_params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
# VAE is broken in float16, which mps defaults to
|
||||
if self._torch_dtype == torch.float16:
|
||||
try:
|
||||
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
|
||||
except TypeError:
|
||||
vae_dtype = torch.float32
|
||||
else:
|
||||
vae_dtype = self._torch_dtype
|
||||
model.to(vae_dtype)
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
# VAE is broken in float16, which mps defaults to
|
||||
if self._torch_dtype == torch.float16:
|
||||
try:
|
||||
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
|
||||
except TypeError:
|
||||
vae_dtype = torch.float32
|
||||
else:
|
||||
vae_dtype = self._torch_dtype
|
||||
model.to(vae_dtype)
|
||||
|
||||
return model
|
||||
|
||||
@@ -183,7 +183,9 @@ class T5EncoderCheckpointModel(ModelLoader):
|
||||
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
|
||||
return T5EncoderModel.from_pretrained(
|
||||
Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
@@ -217,17 +219,20 @@ class FluxCheckpointModel(ModelLoader):
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
model = Flux(params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
|
||||
new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
|
||||
self._ram_cache.make_room(new_sd_size)
|
||||
for k in sd.keys():
|
||||
# We need to cast to bfloat16 due to it being the only currently supported dtype for inference
|
||||
sd[k] = sd[k].to(torch.bfloat16)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
sd = load_file(model_path)
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
|
||||
new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
|
||||
self._ram_cache.make_room(new_sd_size)
|
||||
for k in sd.keys():
|
||||
# We need to cast to bfloat16 due to it being the only currently supported dtype for inference
|
||||
sd[k] = sd[k].to(torch.bfloat16)
|
||||
|
||||
flux_params = infer_flux_params_from_state_dict(sd)
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(flux_params)
|
||||
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
|
||||
@@ -258,11 +263,11 @@ class FluxGGUFCheckpointModel(ModelLoader):
|
||||
assert isinstance(config, MainGGUFCheckpointConfig)
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params[config.config_path])
|
||||
|
||||
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
|
||||
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
|
||||
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
|
||||
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
|
||||
|
||||
# HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
|
||||
# We override the shape here to fix the issue.
|
||||
|
||||
@@ -76,6 +76,7 @@
|
||||
"konva": "^9.3.15",
|
||||
"lodash-es": "^4.17.21",
|
||||
"lru-cache": "^11.0.1",
|
||||
"mtwist": "^1.0.2",
|
||||
"nanoid": "^5.0.7",
|
||||
"nanostores": "^0.11.3",
|
||||
"new-github-issue-url": "^1.0.0",
|
||||
|
||||
7
invokeai/frontend/web/pnpm-lock.yaml
generated
7
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -77,6 +77,9 @@ dependencies:
|
||||
lru-cache:
|
||||
specifier: ^11.0.1
|
||||
version: 11.0.1
|
||||
mtwist:
|
||||
specifier: ^1.0.2
|
||||
version: 1.0.2
|
||||
nanoid:
|
||||
specifier: ^5.0.7
|
||||
version: 5.0.7
|
||||
@@ -7016,6 +7019,10 @@ packages:
|
||||
/ms@2.1.3:
|
||||
resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==}
|
||||
|
||||
/mtwist@1.0.2:
|
||||
resolution: {integrity: sha512-eRsSga5jkLg7nNERPOV8vDNxgSwuEcj5upQfJcT0gXfJwXo3pMc7xOga0fu8rXHyrxzl7GFVWWDuaPQgpKDvgw==}
|
||||
dev: false
|
||||
|
||||
/muggle-string@0.3.1:
|
||||
resolution: {integrity: sha512-ckmWDJjphvd/FvZawgygcUeQCxzvohjFO5RxTjj4eq8kw359gFF3E1brjfI+viLMxss5JrHTDRHZvu2/tuy0Qg==}
|
||||
dev: true
|
||||
|
||||
@@ -177,7 +177,17 @@
|
||||
"none": "None",
|
||||
"new": "New",
|
||||
"generating": "Generating",
|
||||
"warnings": "Warnings"
|
||||
"warnings": "Warnings",
|
||||
"start": "Start",
|
||||
"count": "Count",
|
||||
"step": "Step",
|
||||
"end": "End",
|
||||
"min": "Min",
|
||||
"max": "Max",
|
||||
"values": "Values",
|
||||
"resetToDefaults": "Reset to Defaults",
|
||||
"seed": "Seed",
|
||||
"combinatorial": "Combinatorial"
|
||||
},
|
||||
"hrf": {
|
||||
"hrf": "High Resolution Fix",
|
||||
@@ -850,6 +860,19 @@
|
||||
"defaultVAE": "Default VAE"
|
||||
},
|
||||
"nodes": {
|
||||
"arithmeticSequence": "Arithmetic Sequence",
|
||||
"linearDistribution": "Linear Distribution",
|
||||
"uniformRandomDistribution": "Uniform Random Distribution",
|
||||
"parseString": "Parse String",
|
||||
"splitOn": "Split On",
|
||||
"noBatchGroup": "no group",
|
||||
"generatorNRandomValues_one": "{{count}} random value",
|
||||
"generatorNRandomValues_other": "{{count}} random values",
|
||||
"generatorNoValues": "empty",
|
||||
"generatorLoading": "loading",
|
||||
"generatorLoadFromFile": "Load from File",
|
||||
"dynamicPromptsRandom": "Dynamic Prompts (Random)",
|
||||
"dynamicPromptsCombinatorial": "Dynamic Prompts (Combinatorial)",
|
||||
"addNode": "Add Node",
|
||||
"addNodeToolTip": "Add Node (Shift+A, Space)",
|
||||
"addLinearView": "Add to Linear View",
|
||||
@@ -989,7 +1012,11 @@
|
||||
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
|
||||
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
|
||||
"modelAccessError": "Unable to find model {{key}}, resetting to default",
|
||||
"saveToGallery": "Save To Gallery"
|
||||
"saveToGallery": "Save To Gallery",
|
||||
"addItem": "Add Item",
|
||||
"generateValues": "Generate Values",
|
||||
"floatRangeGenerator": "Float Range Generator",
|
||||
"integerRangeGenerator": "Integer Range Generator"
|
||||
},
|
||||
"parameters": {
|
||||
"aspect": "Aspect",
|
||||
@@ -1024,11 +1051,22 @@
|
||||
"addingImagesTo": "Adding images to",
|
||||
"invoke": "Invoke",
|
||||
"missingFieldTemplate": "Missing field template",
|
||||
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: missing input",
|
||||
"missingInputForField": "missing input",
|
||||
"missingNodeTemplate": "Missing node template",
|
||||
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} empty collection",
|
||||
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: too few items, minimum {{minItems}}",
|
||||
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: too many items, maximum {{maxItems}}",
|
||||
"emptyBatches": "empty batches",
|
||||
"batchNodeNotConnected": "Batch node not connected: {{label}}",
|
||||
"batchNodeEmptyCollection": "Some batch nodes have empty collections",
|
||||
"invalidBatchConfigurationCannotCalculate": "Invalid batch configuration; cannot calculate",
|
||||
"collectionTooFewItems": "too few items, minimum {{minItems}}",
|
||||
"collectionTooManyItems": "too many items, maximum {{maxItems}}",
|
||||
"collectionStringTooLong": "too long, max {{maxLength}}",
|
||||
"collectionStringTooShort": "too short, min {{minLength}}",
|
||||
"collectionNumberGTMax": "{{value}} > {{maximum}} (inc max)",
|
||||
"collectionNumberLTMin": "{{value}} < {{minimum}} (inc min)",
|
||||
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)",
|
||||
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)",
|
||||
"collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}",
|
||||
"batchNodeCollectionSizeMismatch": "Collection size mismatch on Batch {{batchGroupId}}",
|
||||
"noModelSelected": "No model selected",
|
||||
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
|
||||
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
|
||||
@@ -1100,7 +1138,8 @@
|
||||
"perPromptLabel": "Seed per Image",
|
||||
"perPromptDesc": "Use a different seed for each image"
|
||||
},
|
||||
"loading": "Generating Dynamic Prompts..."
|
||||
"loading": "Generating Dynamic Prompts...",
|
||||
"promptsToGenerate": "Prompts to Generate"
|
||||
},
|
||||
"sdxl": {
|
||||
"cfgScale": "CFG Scale",
|
||||
@@ -1932,6 +1971,24 @@
|
||||
"description": "Generates an edge map from the selected layer using the PiDiNet edge detection model.",
|
||||
"scribble": "Scribble",
|
||||
"quantize_edges": "Quantize Edges"
|
||||
},
|
||||
"img_blur": {
|
||||
"label": "Blur Image",
|
||||
"description": "Blurs the selected layer.",
|
||||
"blur_type": "Blur Type",
|
||||
"blur_radius": "Radius",
|
||||
"gaussian_type": "Gaussian",
|
||||
"box_type": "Box"
|
||||
},
|
||||
"img_noise": {
|
||||
"label": "Noise Image",
|
||||
"description": "Adds noise to the selected layer.",
|
||||
"noise_type": "Noise Type",
|
||||
"noise_amount": "Amount",
|
||||
"gaussian_type": "Gaussian",
|
||||
"salt_and_pepper_type": "Salt and Pepper",
|
||||
"noise_color": "Colored Noise",
|
||||
"size": "Noise Size"
|
||||
}
|
||||
},
|
||||
"transform": {
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { resolveBatchValue } from 'features/queue/store/readiness';
|
||||
import { groupBy } from 'lodash-es';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { Batch, BatchConfig } from 'services/api/types';
|
||||
|
||||
const log = logger('workflows');
|
||||
|
||||
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
@@ -33,28 +31,54 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
|
||||
const data: Batch['data'] = [];
|
||||
|
||||
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
|
||||
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
|
||||
for (const node of imageBatchNodes) {
|
||||
const images = node.data.inputs['images'];
|
||||
if (!isImageFieldCollectionInputInstance(images)) {
|
||||
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
|
||||
break;
|
||||
}
|
||||
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
|
||||
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
|
||||
for (const edge of edgesFromImageBatch) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
const invocationNodes = nodes.nodes.filter(isInvocationNode);
|
||||
const batchNodes = invocationNodes.filter(isBatchNode);
|
||||
|
||||
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
|
||||
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
|
||||
|
||||
// Then, we will create a batch data collection item for each group
|
||||
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
|
||||
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
|
||||
for (const node of batchNodes) {
|
||||
const value = resolveBatchValue(node, invocationNodes, nodes.edges);
|
||||
const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value';
|
||||
const edgesFromBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === sourceHandle);
|
||||
if (batchGroupId !== 'None') {
|
||||
// If this batch node has a batch_group_id, we will zip the data collection items
|
||||
for (const edge of edgesFromBatch) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
zippedBatchDataCollectionItems.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items: value,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Otherwise add the data collection items to root of the batch so they are not zipped
|
||||
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
for (const edge of edgesFromBatch) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
productBatchDataCollectionItems.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items: value,
|
||||
});
|
||||
}
|
||||
if (productBatchDataCollectionItems.length > 0) {
|
||||
data.push(productBatchDataCollectionItems);
|
||||
}
|
||||
}
|
||||
batchDataCollectionItem.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items: images.value,
|
||||
});
|
||||
}
|
||||
if (batchDataCollectionItem.length > 0) {
|
||||
data.push(batchDataCollectionItem);
|
||||
|
||||
// Finally, if this batch data collection item has any items, add it to the data array
|
||||
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
|
||||
data.push(zippedBatchDataCollectionItems);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { BlurFilterConfig, BlurTypes } from 'features/controlLayers/store/filters';
|
||||
import { IMAGE_FILTERS, isBlurTypes } from 'features/controlLayers/store/filters';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { FilterComponentProps } from './types';
|
||||
|
||||
type Props = FilterComponentProps<BlurFilterConfig>;
|
||||
const DEFAULTS = IMAGE_FILTERS.img_blur.buildDefaults();
|
||||
|
||||
export const FilterBlur = memo(({ onChange, config }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const handleBlurTypeChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isBlurTypes(v?.value)) {
|
||||
return;
|
||||
}
|
||||
onChange({ ...config, blur_type: v.value });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const handleRadiusChange = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...config, radius: v });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const options: { label: string; value: BlurTypes }[] = useMemo(
|
||||
() => [
|
||||
{ label: t('controlLayers.filter.img_blur.gaussian_type'), value: 'gaussian' },
|
||||
{ label: t('controlLayers.filter.img_blur.box_type'), value: 'box' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.filter((o) => o.value === config.blur_type)[0], [options, config.blur_type]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_type')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={handleBlurTypeChange} isSearchable={false} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_radius')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={config.radius}
|
||||
defaultValue={DEFAULTS.radius}
|
||||
onChange={handleRadiusChange}
|
||||
min={1}
|
||||
max={64}
|
||||
step={0.1}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={config.radius}
|
||||
defaultValue={DEFAULTS.radius}
|
||||
onChange={handleRadiusChange}
|
||||
min={1}
|
||||
max={4096}
|
||||
step={0.1}
|
||||
/>
|
||||
</FormControl>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
FilterBlur.displayName = 'FilterBlur';
|
||||
@@ -0,0 +1,111 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import type { NoiseFilterConfig, NoiseTypes } from 'features/controlLayers/store/filters';
|
||||
import { IMAGE_FILTERS, isNoiseTypes } from 'features/controlLayers/store/filters';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { FilterComponentProps } from './types';
|
||||
|
||||
type Props = FilterComponentProps<NoiseFilterConfig>;
|
||||
const DEFAULTS = IMAGE_FILTERS.img_noise.buildDefaults();
|
||||
|
||||
export const FilterNoise = memo(({ onChange, config }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const handleNoiseTypeChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isNoiseTypes(v?.value)) {
|
||||
return;
|
||||
}
|
||||
onChange({ ...config, noise_type: v.value });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const handleAmountChange = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...config, amount: v });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const handleColorChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChange({ ...config, noise_color: e.target.checked });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const handleSizeChange = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...config, size: v });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const options: { label: string; value: NoiseTypes }[] = useMemo(
|
||||
() => [
|
||||
{ label: t('controlLayers.filter.img_noise.gaussian_type'), value: 'gaussian' },
|
||||
{ label: t('controlLayers.filter.img_noise.salt_and_pepper_type'), value: 'salt_and_pepper' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.filter((o) => o.value === config.noise_type)[0], [options, config.noise_type]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_type')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={handleNoiseTypeChange} isSearchable={false} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_amount')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={config.amount}
|
||||
defaultValue={DEFAULTS.amount}
|
||||
onChange={handleAmountChange}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.01}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={config.amount}
|
||||
defaultValue={DEFAULTS.amount}
|
||||
onChange={handleAmountChange}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.01}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_noise.size')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={config.size}
|
||||
defaultValue={DEFAULTS.size}
|
||||
onChange={handleSizeChange}
|
||||
min={1}
|
||||
max={16}
|
||||
step={1}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={config.size}
|
||||
defaultValue={DEFAULTS.size}
|
||||
onChange={handleSizeChange}
|
||||
min={1}
|
||||
max={256}
|
||||
step={1}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl w="max-content">
|
||||
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_color')}</FormLabel>
|
||||
<Switch defaultChecked={DEFAULTS.noise_color} isChecked={config.noise_color} onChange={handleColorChange} />
|
||||
</FormControl>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
FilterNoise.displayName = 'Filternoise';
|
||||
@@ -1,4 +1,5 @@
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { FilterBlur } from 'features/controlLayers/components/Filters/FilterBlur';
|
||||
import { FilterCannyEdgeDetection } from 'features/controlLayers/components/Filters/FilterCannyEdgeDetection';
|
||||
import { FilterColorMap } from 'features/controlLayers/components/Filters/FilterColorMap';
|
||||
import { FilterContentShuffle } from 'features/controlLayers/components/Filters/FilterContentShuffle';
|
||||
@@ -8,6 +9,7 @@ import { FilterHEDEdgeDetection } from 'features/controlLayers/components/Filter
|
||||
import { FilterLineartEdgeDetection } from 'features/controlLayers/components/Filters/FilterLineartEdgeDetection';
|
||||
import { FilterMediaPipeFaceDetection } from 'features/controlLayers/components/Filters/FilterMediaPipeFaceDetection';
|
||||
import { FilterMLSDDetection } from 'features/controlLayers/components/Filters/FilterMLSDDetection';
|
||||
import { FilterNoise } from 'features/controlLayers/components/Filters/FilterNoise';
|
||||
import { FilterPiDiNetEdgeDetection } from 'features/controlLayers/components/Filters/FilterPiDiNetEdgeDetection';
|
||||
import { FilterSpandrel } from 'features/controlLayers/components/Filters/FilterSpandrel';
|
||||
import type { FilterConfig } from 'features/controlLayers/store/filters';
|
||||
@@ -19,6 +21,10 @@ type Props = { filterConfig: FilterConfig; onChange: (filterConfig: FilterConfig
|
||||
export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (filterConfig.type === 'img_blur') {
|
||||
return <FilterBlur config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
if (filterConfig.type === 'canny_edge_detection') {
|
||||
return <FilterCannyEdgeDetection config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
@@ -59,6 +65,10 @@ export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
|
||||
return <FilterPiDiNetEdgeDetection config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
if (filterConfig.type === 'img_noise') {
|
||||
return <FilterNoise config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
if (filterConfig.type === 'spandrel_filter') {
|
||||
return <FilterSpandrel config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
@@ -297,10 +297,9 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
const imageState = imageDTOToImageObject(filterResult.value);
|
||||
this.$imageState.set(imageState);
|
||||
|
||||
// Destroy any existing masked image and create a new one
|
||||
if (this.imageModule) {
|
||||
this.imageModule.destroy();
|
||||
}
|
||||
// Stash the existing image module - we will destroy it after the new image is rendered to prevent a flash
|
||||
// of an empty layer
|
||||
const oldImageModule = this.imageModule;
|
||||
|
||||
this.imageModule = new CanvasObjectImage(imageState, this);
|
||||
|
||||
@@ -309,6 +308,16 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
|
||||
this.konva.group.add(this.imageModule.konva.group);
|
||||
|
||||
// The filtered image have some transparency, so we need to hide the objects of the parent entity to prevent the
|
||||
// two images from blending. We will show the objects again in the teardown method, which is always called after
|
||||
// the filter finishes (applied or canceled).
|
||||
this.parent.renderer.hideObjects();
|
||||
|
||||
if (oldImageModule) {
|
||||
// Destroy the old image module now that the new one is rendered
|
||||
oldImageModule.destroy();
|
||||
}
|
||||
|
||||
// The porcessing is complete, set can set the last processed hash and isProcessing to false
|
||||
this.$lastProcessedHash.set(hash);
|
||||
|
||||
@@ -424,6 +433,8 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
|
||||
teardown = () => {
|
||||
this.unsubscribe();
|
||||
// Re-enable the objects of the parent entity
|
||||
this.parent.renderer.showObjects();
|
||||
this.konva.group.remove();
|
||||
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
|
||||
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.
|
||||
|
||||
@@ -185,6 +185,14 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
|
||||
return didRender;
|
||||
};
|
||||
|
||||
hideObjects = () => {
|
||||
this.konva.objectGroup.hide();
|
||||
};
|
||||
|
||||
showObjects = () => {
|
||||
this.konva.objectGroup.show();
|
||||
};
|
||||
|
||||
adoptObjectRenderer = (renderer: AnyObjectRenderer) => {
|
||||
this.renderers.set(renderer.id, renderer);
|
||||
renderer.konva.group.moveTo(this.konva.objectGroup);
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
getKonvaNodeDebugAttrs,
|
||||
getPrefixedId,
|
||||
offsetCoord,
|
||||
roundRect,
|
||||
} from 'features/controlLayers/konva/util';
|
||||
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import type { Coordinate, Rect, RectWithRotation } from 'features/controlLayers/store/types';
|
||||
@@ -773,7 +774,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
|
||||
const rect = this.getRelativeRect();
|
||||
const rasterizeResult = await withResultAsync(() =>
|
||||
this.parent.renderer.rasterize({
|
||||
rect,
|
||||
rect: roundRect(rect),
|
||||
replaceObjects: true,
|
||||
ignoreCache: true,
|
||||
attrs: { opacity: 1, filters: [] },
|
||||
|
||||
@@ -740,3 +740,12 @@ export const getColorAtCoordinate = (stage: Konva.Stage, coord: Coordinate): Rgb
|
||||
|
||||
return { r, g, b };
|
||||
};
|
||||
|
||||
export const roundRect = (rect: Rect): Rect => {
|
||||
return {
|
||||
x: Math.round(rect.x),
|
||||
y: Math.round(rect.y),
|
||||
width: Math.round(rect.width),
|
||||
height: Math.round(rect.height),
|
||||
};
|
||||
};
|
||||
|
||||
@@ -95,6 +95,28 @@ const zSpandrelFilterConfig = z.object({
|
||||
});
|
||||
export type SpandrelFilterConfig = z.infer<typeof zSpandrelFilterConfig>;
|
||||
|
||||
const zBlurTypes = z.enum(['gaussian', 'box']);
|
||||
export type BlurTypes = z.infer<typeof zBlurTypes>;
|
||||
export const isBlurTypes = (v: unknown): v is BlurTypes => zBlurTypes.safeParse(v).success;
|
||||
const zBlurFilterConfig = z.object({
|
||||
type: z.literal('img_blur'),
|
||||
blur_type: zBlurTypes,
|
||||
radius: z.number().gte(0),
|
||||
});
|
||||
export type BlurFilterConfig = z.infer<typeof zBlurFilterConfig>;
|
||||
|
||||
const zNoiseTypes = z.enum(['gaussian', 'salt_and_pepper']);
|
||||
export type NoiseTypes = z.infer<typeof zNoiseTypes>;
|
||||
export const isNoiseTypes = (v: unknown): v is NoiseTypes => zNoiseTypes.safeParse(v).success;
|
||||
const zNoiseFilterConfig = z.object({
|
||||
type: z.literal('img_noise'),
|
||||
noise_type: zNoiseTypes,
|
||||
amount: z.number().gte(0).lte(1),
|
||||
noise_color: z.boolean(),
|
||||
size: z.number().int().gte(1),
|
||||
});
|
||||
export type NoiseFilterConfig = z.infer<typeof zNoiseFilterConfig>;
|
||||
|
||||
const zFilterConfig = z.discriminatedUnion('type', [
|
||||
zCannyEdgeDetectionFilterConfig,
|
||||
zColorMapFilterConfig,
|
||||
@@ -109,6 +131,8 @@ const zFilterConfig = z.discriminatedUnion('type', [
|
||||
zPiDiNetEdgeDetectionFilterConfig,
|
||||
zDWOpenposeDetectionFilterConfig,
|
||||
zSpandrelFilterConfig,
|
||||
zBlurFilterConfig,
|
||||
zNoiseFilterConfig,
|
||||
]);
|
||||
export type FilterConfig = z.infer<typeof zFilterConfig>;
|
||||
|
||||
@@ -126,6 +150,8 @@ const zFilterType = z.enum([
|
||||
'pidi_edge_detection',
|
||||
'dw_openpose_detection',
|
||||
'spandrel_filter',
|
||||
'img_blur',
|
||||
'img_noise',
|
||||
]);
|
||||
export type FilterType = z.infer<typeof zFilterType>;
|
||||
export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success;
|
||||
@@ -429,6 +455,62 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
|
||||
return true;
|
||||
},
|
||||
},
|
||||
img_blur: {
|
||||
type: 'img_blur',
|
||||
buildDefaults: () => ({
|
||||
type: 'img_blur',
|
||||
blur_type: 'gaussian',
|
||||
radius: 8,
|
||||
}),
|
||||
buildGraph: ({ image_name }, { blur_type, radius }) => {
|
||||
const graph = new Graph(getPrefixedId('img_blur'));
|
||||
const node = graph.addNode({
|
||||
id: getPrefixedId('img_blur'),
|
||||
type: 'img_blur',
|
||||
image: { image_name },
|
||||
blur_type: blur_type,
|
||||
radius: radius,
|
||||
});
|
||||
return {
|
||||
graph,
|
||||
outputNodeId: node.id,
|
||||
};
|
||||
},
|
||||
},
|
||||
img_noise: {
|
||||
type: 'img_noise',
|
||||
buildDefaults: () => ({
|
||||
type: 'img_noise',
|
||||
noise_type: 'gaussian',
|
||||
amount: 0.3,
|
||||
noise_color: true,
|
||||
size: 1,
|
||||
}),
|
||||
buildGraph: ({ image_name }, { noise_type, amount, noise_color, size }) => {
|
||||
const graph = new Graph(getPrefixedId('img_noise'));
|
||||
const node = graph.addNode({
|
||||
id: getPrefixedId('img_noise'),
|
||||
type: 'img_noise',
|
||||
image: { image_name },
|
||||
noise_type: noise_type,
|
||||
amount: amount,
|
||||
noise_color: noise_color,
|
||||
size: size,
|
||||
});
|
||||
const rand = graph.addNode({
|
||||
id: getPrefixedId('rand_int'),
|
||||
use_cache: false,
|
||||
type: 'rand_int',
|
||||
low: 0,
|
||||
high: 2147483647,
|
||||
});
|
||||
graph.addEdge(rand, 'value', node, 'seed');
|
||||
return {
|
||||
graph,
|
||||
outputNodeId: node.id,
|
||||
};
|
||||
},
|
||||
},
|
||||
} as const;
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type {
|
||||
BlurFilterConfig,
|
||||
CannyEdgeDetectionFilterConfig,
|
||||
ColorMapFilterConfig,
|
||||
ContentShuffleFilterConfig,
|
||||
@@ -12,6 +13,7 @@ import type {
|
||||
LineartEdgeDetectionFilterConfig,
|
||||
MediaPipeFaceDetectionFilterConfig,
|
||||
MLSDDetectionFilterConfig,
|
||||
NoiseFilterConfig,
|
||||
NormalMapFilterConfig,
|
||||
PiDiNetEdgeDetectionFilterConfig,
|
||||
} from 'features/controlLayers/store/filters';
|
||||
@@ -54,6 +56,7 @@ describe('Control Adapter Types', () => {
|
||||
});
|
||||
test('Processor Configs', () => {
|
||||
// Types derived from OpenAPI
|
||||
type _BlurFilterConfig = Required<Pick<Invocation<'img_blur'>, 'type' | 'radius' | 'blur_type'>>;
|
||||
type _CannyEdgeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
|
||||
>;
|
||||
@@ -71,6 +74,9 @@ describe('Control Adapter Types', () => {
|
||||
type _MLSDDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
|
||||
>;
|
||||
type _NoiseFilterConfig = Required<
|
||||
Pick<Invocation<'img_noise'>, 'type' | 'noise_type' | 'amount' | 'noise_color' | 'size'>
|
||||
>;
|
||||
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
|
||||
type _DWOpenposeDetectionFilterConfig = Required<
|
||||
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
|
||||
@@ -81,6 +87,7 @@ describe('Control Adapter Types', () => {
|
||||
|
||||
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
|
||||
// The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled.
|
||||
assert<Equals<_BlurFilterConfig, BlurFilterConfig>>();
|
||||
assert<Equals<_CannyEdgeDetectionFilterConfig, CannyEdgeDetectionFilterConfig>>();
|
||||
assert<Equals<_ColorMapFilterConfig, ColorMapFilterConfig>>();
|
||||
assert<Equals<_ContentShuffleFilterConfig, ContentShuffleFilterConfig>>();
|
||||
@@ -90,6 +97,7 @@ describe('Control Adapter Types', () => {
|
||||
assert<Equals<_LineartEdgeDetectionFilterConfig, LineartEdgeDetectionFilterConfig>>();
|
||||
assert<Equals<_MediaPipeFaceDetectionFilterConfig, MediaPipeFaceDetectionFilterConfig>>();
|
||||
assert<Equals<_MLSDDetectionFilterConfig, MLSDDetectionFilterConfig>>();
|
||||
assert<Equals<_NoiseFilterConfig, NoiseFilterConfig>>();
|
||||
assert<Equals<_NormalMapFilterConfig, NormalMapFilterConfig>>();
|
||||
assert<Equals<_DWOpenposeDetectionFilterConfig, DWOpenposeDetectionFilterConfig>>();
|
||||
assert<Equals<_PiDiNetEdgeDetectionFilterConfig, PiDiNetEdgeDetectionFilterConfig>>();
|
||||
|
||||
@@ -11,7 +11,7 @@ import type { DndTargetState } from 'features/dnd/types';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { uploadImages } from 'services/api/endpoints/images';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
@@ -72,11 +72,10 @@ export const FullscreenDropzone = memo(() => {
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
const [dndState, setDndState] = useState<DndTargetState>('idle');
|
||||
|
||||
const uploadFilesSchema = useMemo(() => getFilesSchema(maxImageUploadCount), [maxImageUploadCount]);
|
||||
|
||||
const validateAndUploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
const { getState } = getStore();
|
||||
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
|
||||
const parseResult = uploadFilesSchema.safeParse(files);
|
||||
|
||||
if (!parseResult.success) {
|
||||
@@ -105,7 +104,18 @@ export const FullscreenDropzone = memo(() => {
|
||||
|
||||
uploadImages(uploadArgs);
|
||||
},
|
||||
[maxImageUploadCount, t, uploadFilesSchema]
|
||||
[maxImageUploadCount, t]
|
||||
);
|
||||
|
||||
const onPaste = useCallback(
|
||||
(e: ClipboardEvent) => {
|
||||
if (!e.clipboardData?.files) {
|
||||
return;
|
||||
}
|
||||
const files = Array.from(e.clipboardData.files);
|
||||
validateAndUploadFiles(files);
|
||||
},
|
||||
[validateAndUploadFiles]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -144,24 +154,12 @@ export const FullscreenDropzone = memo(() => {
|
||||
}, [validateAndUploadFiles]);
|
||||
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
|
||||
document.addEventListener(
|
||||
'paste',
|
||||
(e) => {
|
||||
if (!e.clipboardData?.files) {
|
||||
return;
|
||||
}
|
||||
const files = Array.from(e.clipboardData.files);
|
||||
validateAndUploadFiles(files);
|
||||
},
|
||||
{ signal: controller.signal }
|
||||
);
|
||||
window.addEventListener('paste', onPaste);
|
||||
|
||||
return () => {
|
||||
controller.abort();
|
||||
window.removeEventListener('paste', onPaste);
|
||||
};
|
||||
}, [validateAndUploadFiles]);
|
||||
}, [onPaste]);
|
||||
|
||||
return (
|
||||
<Box ref={ref} data-dnd-state={dndState} sx={sx}>
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type {
|
||||
@@ -9,7 +10,6 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import {
|
||||
addImagesToBoard,
|
||||
addImagesToNodeImageFieldCollectionAction,
|
||||
createNewCanvasEntityFromImage,
|
||||
removeImagesFromBoard,
|
||||
replaceCanvasEntityObjectsWithImage,
|
||||
@@ -19,10 +19,14 @@ import {
|
||||
setRegionalGuidanceReferenceImage,
|
||||
setUpscaleInitialImage,
|
||||
} from 'features/imageActions/actions';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('dnd');
|
||||
|
||||
type RecordUnknown = Record<string | symbol, unknown>;
|
||||
|
||||
type DndData<
|
||||
@@ -268,15 +272,27 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
|
||||
}
|
||||
|
||||
const { fieldIdentifier } = targetData.payload;
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
imageDTOs.push(sourceData.payload.imageDTO);
|
||||
} else {
|
||||
imageDTOs.push(...sourceData.payload.imageDTOs);
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
|
||||
const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
newValue.push({ image_name: sourceData.payload.imageDTO.image_name });
|
||||
} else {
|
||||
newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name })));
|
||||
}
|
||||
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue }));
|
||||
},
|
||||
};
|
||||
//#endregion
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
@@ -31,19 +30,15 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
|
||||
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { uniqBy } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const setGlobalReferenceImage = (arg: {
|
||||
imageDTO: ImageDTO;
|
||||
entityIdentifier: CanvasEntityIdentifier<'reference_image'>;
|
||||
@@ -77,54 +72,6 @@ export const setNodeImageFieldImage = (arg: {
|
||||
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
|
||||
};
|
||||
|
||||
export const addImagesToNodeImageFieldCollectionAction = (arg: {
|
||||
imageDTOs: ImageDTO[];
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
|
||||
const uniqueImages = uniqBy(images, 'image_name');
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
|
||||
};
|
||||
|
||||
export const removeImageFromNodeImageFieldCollectionAction = (arg: {
|
||||
imageName: string;
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { imageName, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
|
||||
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
|
||||
};
|
||||
|
||||
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
|
||||
const { imageDTO, dispatch } = arg;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
|
||||
@@ -36,8 +36,10 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
borderWidth: !isSingle(type) ? 4 : 0,
|
||||
borderStyle: 'solid',
|
||||
borderColor: color,
|
||||
borderRadius: isModelType ? 4 : '100%',
|
||||
borderRadius: isModelType || type.batch ? 4 : '100%',
|
||||
zIndex: 1,
|
||||
transform: type.batch ? 'rotate(45deg) translateX(-0.3rem) translateY(-0.3rem)' : 'none',
|
||||
transformOrigin: 'center',
|
||||
};
|
||||
|
||||
if (handleType === 'target') {
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
import { NumberFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent';
|
||||
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
|
||||
import { StringGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import {
|
||||
@@ -21,8 +26,12 @@ import {
|
||||
isControlNetModelFieldInputTemplate,
|
||||
isEnumFieldInputInstance,
|
||||
isEnumFieldInputTemplate,
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isFloatFieldInputInstance,
|
||||
isFloatFieldInputTemplate,
|
||||
isFloatGeneratorFieldInputInstance,
|
||||
isFloatGeneratorFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
isFluxMainModelFieldInputTemplate,
|
||||
isFluxVAEModelFieldInputInstance,
|
||||
@@ -31,8 +40,12 @@ import {
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
isIntegerFieldInputTemplate,
|
||||
isIntegerGeneratorFieldInputInstance,
|
||||
isIntegerGeneratorFieldInputTemplate,
|
||||
isIPAdapterModelFieldInputInstance,
|
||||
isIPAdapterModelFieldInputTemplate,
|
||||
isLoRAModelFieldInputInstance,
|
||||
@@ -51,8 +64,12 @@ import {
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSpandrelImageToImageModelFieldInputInstance,
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isStringGeneratorFieldInputInstance,
|
||||
isStringGeneratorFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
isT2IAdapterModelFieldInputTemplate,
|
||||
isT5EncoderModelFieldInputInstance,
|
||||
@@ -97,6 +114,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
|
||||
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
@@ -105,13 +126,22 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (
|
||||
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
|
||||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
|
||||
) {
|
||||
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
|
||||
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
@@ -216,6 +246,18 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFloatGeneratorFieldInputInstance(fieldInstance) && isFloatGeneratorFieldInputTemplate(fieldTemplate)) {
|
||||
return <FloatGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isIntegerGeneratorFieldInputInstance(fieldInstance) && isIntegerGeneratorFieldInputTemplate(fieldTemplate)) {
|
||||
return <IntegerGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isStringGeneratorFieldInputInstance(fieldInstance) && isStringGeneratorFieldInputTemplate(fieldTemplate)) {
|
||||
return <StringGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (fieldTemplate) {
|
||||
// Fallback for when there is no component for the type
|
||||
return null;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectInvocationNode, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
@@ -18,7 +18,7 @@ export const InvocationInputFieldCheck = memo(({ nodeId, fieldName, children }:
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
createMemoizedSelector(selectNodesSlice, (nodesSlice) => {
|
||||
const node = selectInvocationNode(nodesSlice, nodeId);
|
||||
const instance = node.data.inputs[fieldName];
|
||||
const template = templates[node.data.type];
|
||||
|
||||
@@ -26,7 +26,7 @@ const EnumFieldInputComponent = (props: FieldComponentProps<EnumFieldInputInstan
|
||||
);
|
||||
|
||||
return (
|
||||
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value}>
|
||||
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value} size="sm">
|
||||
{fieldTemplate.options.map((option) => (
|
||||
<option key={option} value={option}>
|
||||
{fieldTemplate.ui_choice_labels ? fieldTemplate.ui_choice_labels[option] : option}
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { FloatGeneratorArithmeticSequence } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type FloatGeneratorArithmeticSequenceSettingsProps = {
|
||||
state: FloatGeneratorArithmeticSequence;
|
||||
onChange: (state: FloatGeneratorArithmeticSequence) => void;
|
||||
};
|
||||
export const FloatGeneratorArithmeticSequenceSettings = memo(
|
||||
({ state, onChange }: FloatGeneratorArithmeticSequenceSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeStart = useCallback(
|
||||
(start: number) => {
|
||||
onChange({ ...state, start });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeStep = useCallback(
|
||||
(step: number) => {
|
||||
onChange({ ...state, step });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={state.start}
|
||||
onChange={onChangeStart}
|
||||
min={-Infinity}
|
||||
max={Infinity}
|
||||
step={0.01}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.step')}</FormLabel>
|
||||
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
FloatGeneratorArithmeticSequenceSettings.displayName = 'FloatGeneratorArithmeticSequenceSettings';
|
||||
@@ -0,0 +1,120 @@
|
||||
import { Flex, Select, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { FloatGeneratorArithmeticSequenceSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorArithmeticSequenceSettings';
|
||||
import { FloatGeneratorLinearDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorLinearDistributionSettings';
|
||||
import { FloatGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorParseStringSettings';
|
||||
import { FloatGeneratorUniformRandomDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorUniformRandomDistributionSettings';
|
||||
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
|
||||
import { fieldFloatGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FloatGeneratorFieldInputInstance, FloatGeneratorFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import {
|
||||
FloatGeneratorArithmeticSequenceType,
|
||||
FloatGeneratorLinearDistributionType,
|
||||
FloatGeneratorParseStringType,
|
||||
FloatGeneratorUniformRandomDistributionType,
|
||||
getFloatGeneratorDefaults,
|
||||
resolveFloatGeneratorField,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isNil, round } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
export const FloatGeneratorFieldInputComponent = memo(
|
||||
(props: FieldComponentProps<FloatGeneratorFieldInputInstance, FloatGeneratorFieldInputTemplate>) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: FloatGeneratorFieldInputInstance['value']) => {
|
||||
dispatch(
|
||||
fieldFloatGeneratorValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const onChangeGeneratorType = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
const value = getFloatGeneratorDefaults(e.target.value as FloatGeneratorFieldInputInstance['value']['type']);
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFloatGeneratorValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const [debouncedField] = useDebounce(field, 300);
|
||||
const resolvedValuesAsString = useMemo(() => {
|
||||
if (
|
||||
debouncedField.value.type === FloatGeneratorUniformRandomDistributionType &&
|
||||
isNil(debouncedField.value.seed)
|
||||
) {
|
||||
const { count } = debouncedField.value;
|
||||
return `<${t('nodes.generatorNRandomValues', { count })}>`;
|
||||
}
|
||||
const resolvedValues = resolveFloatGeneratorField(debouncedField);
|
||||
if (resolvedValues.length === 0) {
|
||||
return `<${t('nodes.generatorNoValues')}>`;
|
||||
} else {
|
||||
return resolvedValues.map((val) => round(val, 2)).join(', ');
|
||||
}
|
||||
}, [debouncedField, t]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
|
||||
<option value={FloatGeneratorArithmeticSequenceType}>{t('nodes.arithmeticSequence')}</option>
|
||||
<option value={FloatGeneratorLinearDistributionType}>{t('nodes.linearDistribution')}</option>
|
||||
<option value={FloatGeneratorUniformRandomDistributionType}>{t('nodes.uniformRandomDistribution')}</option>
|
||||
<option value={FloatGeneratorParseStringType}>{t('nodes.parseString')}</option>
|
||||
</Select>
|
||||
{field.value.type === FloatGeneratorArithmeticSequenceType && (
|
||||
<FloatGeneratorArithmeticSequenceSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === FloatGeneratorLinearDistributionType && (
|
||||
<FloatGeneratorLinearDistributionSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === FloatGeneratorUniformRandomDistributionType && (
|
||||
<FloatGeneratorUniformRandomDistributionSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === FloatGeneratorParseStringType && (
|
||||
<FloatGeneratorParseStringSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
|
||||
<Flex w="full" h="auto">
|
||||
<OverlayScrollbarsComponent
|
||||
className="nodrag nowheel"
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
|
||||
{resolvedValuesAsString}
|
||||
</Text>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
FloatGeneratorFieldInputComponent.displayName = 'FloatGeneratorFieldInputComponent';
|
||||
@@ -0,0 +1,57 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { FloatGeneratorLinearDistribution } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type FloatGeneratorLinearDistributionSettingsProps = {
|
||||
state: FloatGeneratorLinearDistribution;
|
||||
onChange: (state: FloatGeneratorLinearDistribution) => void;
|
||||
};
|
||||
export const FloatGeneratorLinearDistributionSettings = memo(
|
||||
({ state, onChange }: FloatGeneratorLinearDistributionSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeStart = useCallback(
|
||||
(start: number) => {
|
||||
onChange({ ...state, start });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeEnd = useCallback(
|
||||
(end: number) => {
|
||||
onChange({ ...state, end });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput
|
||||
value={state.start}
|
||||
onChange={onChangeStart}
|
||||
min={-Infinity}
|
||||
max={Infinity}
|
||||
step={0.01}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.end')}</FormLabel>
|
||||
<CompositeNumberInput value={state.end} onChange={onChangeEnd} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
FloatGeneratorLinearDistributionSettings.displayName = 'FloatGeneratorLinearDistributionSettings';
|
||||
@@ -0,0 +1,39 @@
|
||||
import { Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
|
||||
import type { FloatGeneratorParseString } from 'features/nodes/types/field';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type FloatGeneratorParseStringSettingsProps = {
|
||||
state: FloatGeneratorParseString;
|
||||
onChange: (state: FloatGeneratorParseString) => void;
|
||||
};
|
||||
export const FloatGeneratorParseStringSettings = memo(({ state, onChange }: FloatGeneratorParseStringSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeSplitOn = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChange({ ...state, splitOn: e.target.value });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(input: string) => {
|
||||
onChange({ ...state, input });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('nodes.splitOn')}</FormLabel>
|
||||
<Input value={state.splitOn} onChange={onChangeSplitOn} />
|
||||
</FormControl>
|
||||
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
FloatGeneratorParseStringSettings.displayName = 'FloatGeneratorParseStringSettings';
|
||||
@@ -0,0 +1,78 @@
|
||||
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { FloatGeneratorUniformRandomDistribution } from 'features/nodes/types/field';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type FloatGeneratorUniformRandomDistributionSettingsProps = {
|
||||
state: FloatGeneratorUniformRandomDistribution;
|
||||
onChange: (state: FloatGeneratorUniformRandomDistribution) => void;
|
||||
};
|
||||
export const FloatGeneratorUniformRandomDistributionSettings = memo(
|
||||
({ state, onChange }: FloatGeneratorUniformRandomDistributionSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeMin = useCallback(
|
||||
(min: number) => {
|
||||
onChange({ ...state, min });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeMax = useCallback(
|
||||
(max: number) => {
|
||||
onChange({ ...state, max });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onToggleSeed = useCallback(() => {
|
||||
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
|
||||
}, [onChange, state]);
|
||||
const onChangeSeed = useCallback(
|
||||
(seed?: number | null) => {
|
||||
onChange({ ...state, seed });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.min')}</FormLabel>
|
||||
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.max')}</FormLabel>
|
||||
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel alignItems="center" justifyContent="space-between" m={0} display="flex" w="full">
|
||||
{t('common.seed')}
|
||||
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
|
||||
</FormLabel>
|
||||
<CompositeNumberInput
|
||||
isDisabled={isNil(state.seed)}
|
||||
// This cast is save only because we disable the element when seed is not a number - the `...` is
|
||||
// rendered in the input field in this case
|
||||
value={state.seed ?? ('...' as unknown as number)}
|
||||
onChange={onChangeSeed}
|
||||
min={-Infinity}
|
||||
max={Infinity}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
FloatGeneratorUniformRandomDistributionSettings.displayName = 'FloatGeneratorUniformRandomDistributionSettings';
|
||||
@@ -0,0 +1,85 @@
|
||||
import { FormControl, FormLabel, IconButton, Spacer, Textarea } from '@invoke-ai/ui-library';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { isString } from 'lodash-es';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiUploadFill } from 'react-icons/pi';
|
||||
|
||||
const MAX_SIZE = 1024 * 128; // 128KB, we don't want to load huge files into node values...
|
||||
|
||||
type Props = {
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
};
|
||||
|
||||
export const GeneratorTextareaWithFileUpload = memo(({ value, onChange }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const onDropAccepted = useCallback(
|
||||
(files: File[]) => {
|
||||
const file = files[0];
|
||||
if (!file) {
|
||||
return;
|
||||
}
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => {
|
||||
const result = reader.result;
|
||||
if (!isString(result)) {
|
||||
return;
|
||||
}
|
||||
onChange(result);
|
||||
};
|
||||
reader.onerror = () => {
|
||||
toast({
|
||||
title: 'Failed to load file',
|
||||
status: 'error',
|
||||
});
|
||||
};
|
||||
reader.readAsText(file);
|
||||
},
|
||||
[onChange]
|
||||
);
|
||||
|
||||
const { getInputProps, getRootProps } = useDropzone({
|
||||
accept: { 'text/csv': ['.csv'], 'text/plain': ['.txt'] },
|
||||
maxSize: MAX_SIZE,
|
||||
onDropAccepted,
|
||||
noDrag: true,
|
||||
multiple: false,
|
||||
});
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
onChange(e.target.value);
|
||||
},
|
||||
[onChange]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl orientation="vertical" position="relative" alignItems="stretch">
|
||||
<FormLabel m={0} display="flex" alignItems="center">
|
||||
{t('common.input')}
|
||||
<Spacer />
|
||||
<IconButton
|
||||
tooltip={t('nodes.generatorLoadFromFile')}
|
||||
aria-label={t('nodes.generatorLoadFromFile')}
|
||||
icon={<PiUploadFill />}
|
||||
variant="link"
|
||||
{...getRootProps()}
|
||||
/>
|
||||
<input {...getInputProps()} />
|
||||
</FormLabel>
|
||||
<Textarea
|
||||
className="nowheel nodrag nopan"
|
||||
value={value}
|
||||
onChange={onChangeInput}
|
||||
p={2}
|
||||
resize="none"
|
||||
rows={5}
|
||||
fontSize="sm"
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
GeneratorTextareaWithFileUpload.displayName = 'GeneratorTextareaWithFileUpload';
|
||||
@@ -10,9 +10,9 @@ import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { DndImageIcon } from 'features/dnd/DndImageIcon';
|
||||
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -61,15 +61,12 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
);
|
||||
|
||||
const onRemoveImage = useCallback(
|
||||
(imageName: string) => {
|
||||
removeImageFromNodeImageFieldCollectionAction({
|
||||
imageName,
|
||||
fieldIdentifier: { nodeId, fieldName: field.name },
|
||||
dispatch: store.dispatch,
|
||||
getState: store.getState,
|
||||
});
|
||||
(index: number) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue.splice(index, 1);
|
||||
store.dispatch(fieldImageCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, nodeId, store.dispatch, store.getState]
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -90,7 +87,7 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
isError={isInvalid}
|
||||
onUpload={onUpload}
|
||||
fontSize={24}
|
||||
variant="outline"
|
||||
variant="ghost"
|
||||
/>
|
||||
)}
|
||||
{field.value && field.value.length > 0 && (
|
||||
@@ -102,9 +99,9 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Grid w="full" h="full" templateColumns="repeat(4, 1fr)" gap={1}>
|
||||
{field.value.map(({ image_name }) => (
|
||||
<GridItem key={image_name} position="relative" className="nodrag">
|
||||
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
|
||||
{field.value.map((value, index) => (
|
||||
<GridItem key={index} position="relative" className="nodrag">
|
||||
<ImageGridItemContent value={value} index={index} onRemoveImage={onRemoveImage} />
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
@@ -124,11 +121,11 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';
|
||||
|
||||
const ImageGridItemContent = memo(
|
||||
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
|
||||
const query = useGetImageDTOQuery(imageName);
|
||||
({ value, index, onRemoveImage }: { value: ImageField; index: number; onRemoveImage: (index: number) => void }) => {
|
||||
const query = useGetImageDTOQuery(value.image_name);
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveImage(imageName);
|
||||
}, [imageName, onRemoveImage]);
|
||||
onRemoveImage(index);
|
||||
}, [index, onRemoveImage]);
|
||||
|
||||
if (query.isLoading) {
|
||||
return <IAINoContentFallbackWithSpinner />;
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { IntegerGeneratorArithmeticSequence } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type IntegerGeneratorArithmeticSequenceSettingsProps = {
|
||||
state: IntegerGeneratorArithmeticSequence;
|
||||
onChange: (state: IntegerGeneratorArithmeticSequence) => void;
|
||||
};
|
||||
export const IntegerGeneratorArithmeticSequenceSettings = memo(
|
||||
({ state, onChange }: IntegerGeneratorArithmeticSequenceSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeStart = useCallback(
|
||||
(start: number) => {
|
||||
onChange({ ...state, start });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeStep = useCallback(
|
||||
(step: number) => {
|
||||
onChange({ ...state, step });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.step')}</FormLabel>
|
||||
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
IntegerGeneratorArithmeticSequenceSettings.displayName = 'IntegerGeneratorArithmeticSequenceSettings';
|
||||
@@ -0,0 +1,122 @@
|
||||
import { Flex, Select, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { IntegerGeneratorArithmeticSequenceSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorArithmeticSequenceSettings';
|
||||
import { IntegerGeneratorLinearDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorLinearDistributionSettings';
|
||||
import { IntegerGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorParseStringSettings';
|
||||
import { IntegerGeneratorUniformRandomDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorUniformRandomDistributionSettings';
|
||||
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
|
||||
import { fieldIntegerGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
IntegerGeneratorFieldInputInstance,
|
||||
IntegerGeneratorFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
getIntegerGeneratorDefaults,
|
||||
IntegerGeneratorArithmeticSequenceType,
|
||||
IntegerGeneratorLinearDistributionType,
|
||||
IntegerGeneratorParseStringType,
|
||||
IntegerGeneratorUniformRandomDistributionType,
|
||||
resolveIntegerGeneratorField,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isNil, round } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
export const IntegerGeneratorFieldInputComponent = memo(
|
||||
(props: FieldComponentProps<IntegerGeneratorFieldInputInstance, IntegerGeneratorFieldInputTemplate>) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: IntegerGeneratorFieldInputInstance['value']) => {
|
||||
dispatch(
|
||||
fieldIntegerGeneratorValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const onChangeGeneratorType = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
const value = getIntegerGeneratorDefaults(
|
||||
e.target.value as IntegerGeneratorFieldInputInstance['value']['type']
|
||||
);
|
||||
dispatch(
|
||||
fieldIntegerGeneratorValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const [debouncedField] = useDebounce(field, 300);
|
||||
const resolvedValuesAsString = useMemo(() => {
|
||||
if (
|
||||
debouncedField.value.type === IntegerGeneratorUniformRandomDistributionType &&
|
||||
isNil(debouncedField.value.seed)
|
||||
) {
|
||||
const { count } = debouncedField.value;
|
||||
return `<${t('nodes.generatorNRandomValues', { count })}>`;
|
||||
}
|
||||
const resolvedValues = resolveIntegerGeneratorField(debouncedField);
|
||||
if (resolvedValues.length === 0) {
|
||||
return `<${t('nodes.generatorNoValues')}>`;
|
||||
} else {
|
||||
return resolvedValues.map((val) => round(val, 2)).join(', ');
|
||||
}
|
||||
}, [debouncedField, t]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
|
||||
<option value={IntegerGeneratorArithmeticSequenceType}>{t('nodes.arithmeticSequence')}</option>
|
||||
<option value={IntegerGeneratorLinearDistributionType}>{t('nodes.linearDistribution')}</option>
|
||||
<option value={IntegerGeneratorUniformRandomDistributionType}>{t('nodes.uniformRandomDistribution')}</option>
|
||||
<option value={IntegerGeneratorParseStringType}>{t('nodes.parseString')}</option>
|
||||
</Select>
|
||||
{field.value.type === IntegerGeneratorArithmeticSequenceType && (
|
||||
<IntegerGeneratorArithmeticSequenceSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === IntegerGeneratorLinearDistributionType && (
|
||||
<IntegerGeneratorLinearDistributionSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === IntegerGeneratorUniformRandomDistributionType && (
|
||||
<IntegerGeneratorUniformRandomDistributionSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === IntegerGeneratorParseStringType && (
|
||||
<IntegerGeneratorParseStringSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
|
||||
<Flex w="full" h="auto">
|
||||
<OverlayScrollbarsComponent
|
||||
className="nodrag nowheel"
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
|
||||
{resolvedValuesAsString}
|
||||
</Text>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
IntegerGeneratorFieldInputComponent.displayName = 'IntegerGeneratorFieldInputComponent';
|
||||
@@ -0,0 +1,51 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { IntegerGeneratorLinearDistribution } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type IntegerGeneratorLinearDistributionSettingsProps = {
|
||||
state: IntegerGeneratorLinearDistribution;
|
||||
onChange: (state: IntegerGeneratorLinearDistribution) => void;
|
||||
};
|
||||
export const IntegerGeneratorLinearDistributionSettings = memo(
|
||||
({ state, onChange }: IntegerGeneratorLinearDistributionSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeStart = useCallback(
|
||||
(start: number) => {
|
||||
onChange({ ...state, start });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeEnd = useCallback(
|
||||
(end: number) => {
|
||||
onChange({ ...state, end });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.end')}</FormLabel>
|
||||
<CompositeNumberInput value={state.end} onChange={onChangeEnd} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
IntegerGeneratorLinearDistributionSettings.displayName = 'IntegerGeneratorLinearDistributionSettings';
|
||||
@@ -0,0 +1,41 @@
|
||||
import { Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
|
||||
import type { IntegerGeneratorParseString } from 'features/nodes/types/field';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type IntegerGeneratorParseStringSettingsProps = {
|
||||
state: IntegerGeneratorParseString;
|
||||
onChange: (state: IntegerGeneratorParseString) => void;
|
||||
};
|
||||
export const IntegerGeneratorParseStringSettings = memo(
|
||||
({ state, onChange }: IntegerGeneratorParseStringSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeSplitOn = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChange({ ...state, splitOn: e.target.value });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(input: string) => {
|
||||
onChange({ ...state, input });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('nodes.splitOn')}</FormLabel>
|
||||
<Input value={state.splitOn} onChange={onChangeSplitOn} />
|
||||
</FormControl>
|
||||
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
IntegerGeneratorParseStringSettings.displayName = 'IntegerGeneratorParseStringSettings';
|
||||
@@ -0,0 +1,78 @@
|
||||
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import type { IntegerGeneratorUniformRandomDistribution } from 'features/nodes/types/field';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type IntegerGeneratorUniformRandomDistributionSettingsProps = {
|
||||
state: IntegerGeneratorUniformRandomDistribution;
|
||||
onChange: (state: IntegerGeneratorUniformRandomDistribution) => void;
|
||||
};
|
||||
export const IntegerGeneratorUniformRandomDistributionSettings = memo(
|
||||
({ state, onChange }: IntegerGeneratorUniformRandomDistributionSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeMin = useCallback(
|
||||
(min: number) => {
|
||||
onChange({ ...state, min });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeMax = useCallback(
|
||||
(max: number) => {
|
||||
onChange({ ...state, max });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onToggleSeed = useCallback(() => {
|
||||
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
|
||||
}, [onChange, state]);
|
||||
const onChangeSeed = useCallback(
|
||||
(seed?: number | null) => {
|
||||
onChange({ ...state, seed });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.min')}</FormLabel>
|
||||
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.max')}</FormLabel>
|
||||
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel alignItems="center" justifyContent="space-between" m={0} display="flex" w="full">
|
||||
{t('common.seed')}
|
||||
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
|
||||
</FormLabel>
|
||||
<CompositeNumberInput
|
||||
isDisabled={isNil(state.seed)}
|
||||
// This cast is save only because we disable the element when seed is not a number - the `...` is
|
||||
// rendered in the input field in this case
|
||||
value={state.seed ?? ('...' as unknown as number)}
|
||||
onChange={onChangeSeed}
|
||||
min={-Infinity}
|
||||
max={Infinity}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
IntegerGeneratorUniformRandomDistributionSettings.displayName = 'IntegerGeneratorUniformRandomDistributionSettings';
|
||||
@@ -0,0 +1,237 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Button,
|
||||
CompositeNumberInput,
|
||||
Divider,
|
||||
Flex,
|
||||
FormLabel,
|
||||
Grid,
|
||||
GridItem,
|
||||
IconButton,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import { fieldNumberCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
FloatFieldCollectionInputInstance,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
IntegerFieldCollectionInputInstance,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
const sx = {
|
||||
borderWidth: 1,
|
||||
'&[data-error=true]': {
|
||||
borderColor: 'error.500',
|
||||
borderStyle: 'solid',
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const NumberFieldCollectionInputComponent = memo(
|
||||
(
|
||||
props:
|
||||
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
|
||||
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const store = useAppStore();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isInvalid = useFieldIsInvalid(nodeId, field.name);
|
||||
const isIntegerField = useMemo(() => fieldTemplate.type.name === 'IntegerField', [fieldTemplate.type]);
|
||||
|
||||
const onRemoveNumber = useCallback(
|
||||
(index: number) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue.splice(index, 1);
|
||||
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
const onChangeNumber = useCallback(
|
||||
(index: number, value: number) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue[index] = value;
|
||||
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
const onAddNumber = useCallback(() => {
|
||||
const newValue = field.value ? [...field.value, 0] : [0];
|
||||
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
}, [field.name, field.value, nodeId, store]);
|
||||
|
||||
const min = useMemo(() => {
|
||||
let min = -NUMPY_RAND_MAX;
|
||||
if (!isNil(fieldTemplate.minimum)) {
|
||||
min = fieldTemplate.minimum;
|
||||
}
|
||||
if (!isNil(fieldTemplate.exclusiveMinimum)) {
|
||||
min = fieldTemplate.exclusiveMinimum + 0.01;
|
||||
}
|
||||
return min;
|
||||
}, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);
|
||||
|
||||
const max = useMemo(() => {
|
||||
let max = NUMPY_RAND_MAX;
|
||||
if (!isNil(fieldTemplate.maximum)) {
|
||||
max = fieldTemplate.maximum;
|
||||
}
|
||||
if (!isNil(fieldTemplate.exclusiveMaximum)) {
|
||||
max = fieldTemplate.exclusiveMaximum - 0.01;
|
||||
}
|
||||
return max;
|
||||
}, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);
|
||||
|
||||
const step = useMemo(() => {
|
||||
if (isNil(fieldTemplate.multipleOf)) {
|
||||
return isIntegerField ? 1 : 0.1;
|
||||
}
|
||||
return fieldTemplate.multipleOf;
|
||||
}, [fieldTemplate.multipleOf, isIntegerField]);
|
||||
|
||||
const fineStep = useMemo(() => {
|
||||
if (isNil(fieldTemplate.multipleOf)) {
|
||||
return isIntegerField ? 1 : 0.01;
|
||||
}
|
||||
return fieldTemplate.multipleOf;
|
||||
}, [fieldTemplate.multipleOf, isIntegerField]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
className="nodrag"
|
||||
position="relative"
|
||||
w="full"
|
||||
h="auto"
|
||||
maxH={64}
|
||||
alignItems="stretch"
|
||||
justifyContent="center"
|
||||
p={1}
|
||||
sx={sx}
|
||||
data-error={isInvalid}
|
||||
borderRadius="base"
|
||||
flexDir="column"
|
||||
gap={1}
|
||||
>
|
||||
<Button onClick={onAddNumber} variant="ghost">
|
||||
{t('nodes.addItem')}
|
||||
</Button>
|
||||
{field.value && field.value.length > 0 && (
|
||||
<>
|
||||
<Divider />
|
||||
<OverlayScrollbarsComponent
|
||||
className="nowheel"
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Grid gap={1} gridTemplateColumns="auto 1fr auto" alignItems="center">
|
||||
{field.value.map((value, index) => (
|
||||
<NumberListItemContent
|
||||
key={index}
|
||||
value={value}
|
||||
index={index}
|
||||
min={min}
|
||||
max={max}
|
||||
step={step}
|
||||
fineStep={fineStep}
|
||||
isIntegerField={isIntegerField}
|
||||
onRemoveNumber={onRemoveNumber}
|
||||
onChangeNumber={onChangeNumber}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
</OverlayScrollbarsComponent>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
NumberFieldCollectionInputComponent.displayName = 'NumberFieldCollectionInputComponent';
|
||||
|
||||
type NumberListItemContentProps = {
|
||||
value: number;
|
||||
index: number;
|
||||
isIntegerField: boolean;
|
||||
min: number;
|
||||
max: number;
|
||||
step: number;
|
||||
fineStep: number;
|
||||
onRemoveNumber: (index: number) => void;
|
||||
onChangeNumber: (index: number, value: number) => void;
|
||||
};
|
||||
|
||||
const NumberListItemContent = memo(
|
||||
({
|
||||
value,
|
||||
index,
|
||||
isIntegerField,
|
||||
min,
|
||||
max,
|
||||
step,
|
||||
fineStep,
|
||||
onRemoveNumber,
|
||||
onChangeNumber,
|
||||
}: NumberListItemContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveNumber(index);
|
||||
}, [index, onRemoveNumber]);
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
onChangeNumber(index, isIntegerField ? Math.floor(Number(v)) : Number(v));
|
||||
},
|
||||
[index, isIntegerField, onChangeNumber]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<GridItem>
|
||||
<FormLabel ps={1} m={0}>
|
||||
{index + 1}.
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<CompositeNumberInput
|
||||
onChange={onChange}
|
||||
value={value}
|
||||
min={min}
|
||||
max={max}
|
||||
step={step}
|
||||
fineStep={fineStep}
|
||||
className="nodrag"
|
||||
flexGrow={1}
|
||||
/>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<IconButton
|
||||
tabIndex={-1}
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={onClickRemove}
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('common.delete')}
|
||||
/>
|
||||
</GridItem>
|
||||
</>
|
||||
);
|
||||
}
|
||||
);
|
||||
NumberListItemContent.displayName = 'NumberListItemContent';
|
||||
@@ -0,0 +1,189 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Button, Divider, Flex, FormLabel, Grid, GridItem, IconButton, Input } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import { fieldStringCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
StringFieldCollectionInputInstance,
|
||||
StringFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
const sx = {
|
||||
borderWidth: 1,
|
||||
'&[data-error=true]': {
|
||||
borderColor: 'error.500',
|
||||
borderStyle: 'solid',
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const StringFieldCollectionInputComponent = memo(
|
||||
(props: FieldComponentProps<StringFieldCollectionInputInstance, StringFieldCollectionInputTemplate>) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const store = useAppStore();
|
||||
|
||||
const isInvalid = useFieldIsInvalid(nodeId, field.name);
|
||||
|
||||
const onRemoveString = useCallback(
|
||||
(index: number) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue.splice(index, 1);
|
||||
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
const onChangeString = useCallback(
|
||||
(index: number, value: string) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue[index] = value;
|
||||
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
const onAddString = useCallback(() => {
|
||||
const newValue = field.value ? [...field.value, ''] : [''];
|
||||
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
}, [field.name, field.value, nodeId, store]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
className="nodrag"
|
||||
position="relative"
|
||||
w="full"
|
||||
h="auto"
|
||||
maxH={64}
|
||||
alignItems="stretch"
|
||||
justifyContent="center"
|
||||
p={1}
|
||||
sx={sx}
|
||||
data-error={isInvalid}
|
||||
borderRadius="base"
|
||||
flexDir="column"
|
||||
gap={1}
|
||||
>
|
||||
<Button onClick={onAddString} variant="ghost">
|
||||
{t('nodes.addItem')}
|
||||
</Button>
|
||||
{field.value && field.value.length > 0 && (
|
||||
<>
|
||||
<Divider />
|
||||
<OverlayScrollbarsComponent
|
||||
className="nowheel"
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Grid gap={1} gridTemplateColumns="auto 1fr auto" alignItems="center">
|
||||
{field.value.map((value, index) => (
|
||||
<ListItemContent
|
||||
key={index}
|
||||
value={value}
|
||||
index={index}
|
||||
onRemoveString={onRemoveString}
|
||||
onChangeString={onChangeString}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
</OverlayScrollbarsComponent>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
|
||||
|
||||
type StringListItemContentProps = {
|
||||
value: string;
|
||||
index: number;
|
||||
onRemoveString: (index: number) => void;
|
||||
onChangeString: (index: number, value: string) => void;
|
||||
};
|
||||
|
||||
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveString(index);
|
||||
}, [index, onRemoveString]);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChangeString(index, e.target.value);
|
||||
},
|
||||
[index, onChangeString]
|
||||
);
|
||||
return (
|
||||
<Flex alignItems="center" gap={1}>
|
||||
<Input size="xs" resize="none" value={value} onChange={onChange} />
|
||||
<IconButton
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={onClickRemove}
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('common.remove')}
|
||||
tooltip={t('common.remove')}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
StringListItemContent.displayName = 'StringListItemContent';
|
||||
|
||||
type ListItemContentProps = {
|
||||
value: string;
|
||||
index: number;
|
||||
onRemoveString: (index: number) => void;
|
||||
onChangeString: (index: number, value: string) => void;
|
||||
};
|
||||
|
||||
const ListItemContent = memo(({ value, index, onRemoveString, onChangeString }: ListItemContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveString(index);
|
||||
}, [index, onRemoveString]);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChangeString(index, e.target.value);
|
||||
},
|
||||
[index, onChangeString]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<GridItem>
|
||||
<FormLabel ps={1} m={0}>
|
||||
{index + 1}.
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input size="sm" resize="none" value={value} onChange={onChange} />
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<IconButton
|
||||
tabIndex={-1}
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={onClickRemove}
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('common.delete')}
|
||||
/>
|
||||
</GridItem>
|
||||
</>
|
||||
);
|
||||
});
|
||||
ListItemContent.displayName = 'ListItemContent';
|
||||
@@ -0,0 +1,60 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
|
||||
import type { StringGeneratorDynamicPromptsCombinatorial } from 'features/nodes/types/field';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
type StringGeneratorDynamicPromptsCombinatorialSettingsProps = {
|
||||
state: StringGeneratorDynamicPromptsCombinatorial;
|
||||
onChange: (state: StringGeneratorDynamicPromptsCombinatorial) => void;
|
||||
};
|
||||
export const StringGeneratorDynamicPromptsCombinatorialSettings = memo(
|
||||
({ state, onChange }: StringGeneratorDynamicPromptsCombinatorialSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(input: string) => {
|
||||
onChange({ ...state, input, values: loadingValues });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
);
|
||||
const onChangeMaxPrompts = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...state, maxPrompts: v, values: loadingValues });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
);
|
||||
|
||||
const arg = useMemo(() => {
|
||||
const { input, maxPrompts } = state;
|
||||
return { prompt: input, max_prompts: maxPrompts, combinatorial: true };
|
||||
}, [state]);
|
||||
const [debouncedArg] = useDebounce(arg, 300);
|
||||
|
||||
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoading) {
|
||||
onChange({ ...state, values: loadingValues });
|
||||
} else if (data) {
|
||||
onChange({ ...state, values: data.prompts });
|
||||
} else {
|
||||
onChange({ ...state, values: [] });
|
||||
}
|
||||
}, [data, isLoading, loadingValues, onChange, state]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('dynamicPrompts.maxPrompts')}</FormLabel>
|
||||
<CompositeNumberInput value={state.maxPrompts} onChange={onChangeMaxPrompts} min={1} max={1000} w="full" />
|
||||
</FormControl>
|
||||
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
StringGeneratorDynamicPromptsCombinatorialSettings.displayName = 'StringGeneratorDynamicPromptsCombinatorialSettings';
|
||||
@@ -0,0 +1,87 @@
|
||||
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
|
||||
import type { StringGeneratorDynamicPromptsRandom } from 'features/nodes/types/field';
|
||||
import { isNil, random } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
type StringGeneratorDynamicPromptsRandomSettingsProps = {
|
||||
state: StringGeneratorDynamicPromptsRandom;
|
||||
onChange: (state: StringGeneratorDynamicPromptsRandom) => void;
|
||||
};
|
||||
export const StringGeneratorDynamicPromptsRandomSettings = memo(
|
||||
({ state, onChange }: StringGeneratorDynamicPromptsRandomSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(input: string) => {
|
||||
onChange({ ...state, input, values: loadingValues });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...state, count: v, values: loadingValues });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
);
|
||||
const onToggleSeed = useCallback(() => {
|
||||
onChange({ ...state, seed: isNil(state.seed) ? 0 : null, values: loadingValues });
|
||||
}, [onChange, state, loadingValues]);
|
||||
const onChangeSeed = useCallback(
|
||||
(seed?: number | null) => {
|
||||
onChange({ ...state, seed, values: loadingValues });
|
||||
},
|
||||
[onChange, state, loadingValues]
|
||||
);
|
||||
|
||||
const arg = useMemo(() => {
|
||||
const { input, count, seed } = state;
|
||||
return { prompt: input, max_prompts: count, combinatorial: false, seed: seed ?? random() };
|
||||
}, [state]);
|
||||
const [debouncedArg] = useDebounce(arg, 300);
|
||||
|
||||
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoading) {
|
||||
onChange({ ...state, values: loadingValues });
|
||||
} else if (data) {
|
||||
onChange({ ...state, values: data.prompts });
|
||||
} else {
|
||||
onChange({ ...state, values: [] });
|
||||
}
|
||||
}, [data, isLoading, loadingValues, onChange, state]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<Flex gap={2}>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel alignItems="center" justifyContent="space-between" display="flex" w="full" pe={0.5}>
|
||||
{t('common.seed')}
|
||||
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
|
||||
</FormLabel>
|
||||
<CompositeNumberInput
|
||||
isDisabled={isNil(state.seed)}
|
||||
// This cast is save only because we disable the element when seed is not a number - the `...` is
|
||||
// rendered in the input field in this case
|
||||
value={state.seed ?? ('...' as unknown as number)}
|
||||
onChange={onChangeSeed}
|
||||
min={-Infinity}
|
||||
max={Infinity}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={1000} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
StringGeneratorDynamicPromptsRandomSettings.displayName = 'StringGeneratorDynamicPromptsRandomSettings';
|
||||
@@ -0,0 +1,111 @@
|
||||
import { Flex, Select, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { StringGeneratorDynamicPromptsCombinatorialSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsCombinatorialSettings';
|
||||
import { StringGeneratorDynamicPromptsRandomSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsRandomSettings';
|
||||
import { StringGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorParseStringSettings';
|
||||
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
|
||||
import { fieldStringGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { StringGeneratorFieldInputInstance, StringGeneratorFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import {
|
||||
getStringGeneratorDefaults,
|
||||
resolveStringGeneratorField,
|
||||
StringGeneratorDynamicPromptsCombinatorialType,
|
||||
StringGeneratorDynamicPromptsRandomType,
|
||||
StringGeneratorParseStringType,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
export const StringGeneratorFieldInputComponent = memo(
|
||||
(props: FieldComponentProps<StringGeneratorFieldInputInstance, StringGeneratorFieldInputTemplate>) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: StringGeneratorFieldInputInstance['value']) => {
|
||||
dispatch(
|
||||
fieldStringGeneratorValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const onChangeGeneratorType = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
const value = getStringGeneratorDefaults(e.target.value as StringGeneratorFieldInputInstance['value']['type']);
|
||||
dispatch(
|
||||
fieldStringGeneratorValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const [debouncedField] = useDebounce(field, 300);
|
||||
const resolvedValuesAsString = useMemo(() => {
|
||||
if (debouncedField.value.type === StringGeneratorDynamicPromptsRandomType && isNil(debouncedField.value.seed)) {
|
||||
const { count } = debouncedField.value;
|
||||
return `<${t('nodes.generatorNRandomValues', { count })}>`;
|
||||
}
|
||||
|
||||
const resolvedValues = resolveStringGeneratorField(debouncedField);
|
||||
if (resolvedValues.length === 0) {
|
||||
return `<${t('nodes.generatorNoValues')}>`;
|
||||
} else {
|
||||
return resolvedValues.join(', ');
|
||||
}
|
||||
}, [debouncedField, t]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
|
||||
<option value={StringGeneratorParseStringType}>{t('nodes.parseString')}</option>
|
||||
<option value={StringGeneratorDynamicPromptsRandomType}>{t('nodes.dynamicPromptsRandom')}</option>
|
||||
<option value={StringGeneratorDynamicPromptsCombinatorialType}>
|
||||
{t('nodes.dynamicPromptsCombinatorial')}
|
||||
</option>
|
||||
</Select>
|
||||
{field.value.type === StringGeneratorParseStringType && (
|
||||
<StringGeneratorParseStringSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === StringGeneratorDynamicPromptsRandomType && (
|
||||
<StringGeneratorDynamicPromptsRandomSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === StringGeneratorDynamicPromptsCombinatorialType && (
|
||||
<StringGeneratorDynamicPromptsCombinatorialSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
|
||||
<Flex w="full" h="auto">
|
||||
<OverlayScrollbarsComponent
|
||||
className="nodrag nowheel"
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
|
||||
{resolvedValuesAsString}
|
||||
</Text>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
StringGeneratorFieldInputComponent.displayName = 'StringGeneratorFieldInputComponent';
|
||||
@@ -0,0 +1,41 @@
|
||||
import { Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
|
||||
import type { StringGeneratorParseString } from 'features/nodes/types/field';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type StringGeneratorParseStringSettingsProps = {
|
||||
state: StringGeneratorParseString;
|
||||
onChange: (state: StringGeneratorParseString) => void;
|
||||
};
|
||||
export const StringGeneratorParseStringSettings = memo(
|
||||
({ state, onChange }: StringGeneratorParseStringSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeSplitOn = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChange({ ...state, splitOn: e.target.value });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(input: string) => {
|
||||
onChange({ ...state, input });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={2} flexDir="column">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('nodes.splitOn')}</FormLabel>
|
||||
<Input value={state.splitOn} onChange={onChangeSplitOn} />
|
||||
</FormControl>
|
||||
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
StringGeneratorParseStringSettings.displayName = 'StringGeneratorParseStringSettings';
|
||||
@@ -1,12 +1,14 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Editable, EditableInput, EditablePreview, Flex, useEditableControls } from '@invoke-ai/ui-library';
|
||||
import type { SystemStyleObject, TextProps } from '@invoke-ai/ui-library';
|
||||
import { Box, Editable, EditableInput, Flex, Text, useEditableControls } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
|
||||
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
@@ -17,6 +19,8 @@ type Props = {
|
||||
const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const batchGroupId = useBatchGroupId(nodeId);
|
||||
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
|
||||
const templateTitle = useNodeTemplateTitle(nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -29,6 +33,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
[dispatch, nodeId, title, templateTitle, label, t]
|
||||
);
|
||||
|
||||
const localTitleWithBatchGroupId = useMemo(() => {
|
||||
if (!batchGroupId) {
|
||||
return localTitle;
|
||||
}
|
||||
if (batchGroupId === 'None') {
|
||||
return `${localTitle} (${t('nodes.noBatchGroup')})`;
|
||||
}
|
||||
return `${localTitle} (${batchGroupId})`;
|
||||
}, [batchGroupId, localTitle, t]);
|
||||
|
||||
const handleChange = useCallback((newTitle: string) => {
|
||||
setLocalTitle(newTitle);
|
||||
}, []);
|
||||
@@ -50,7 +64,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
w="full"
|
||||
h="full"
|
||||
>
|
||||
<EditablePreview fontSize="sm" p={0} w="full" noOfLines={1} />
|
||||
<Preview
|
||||
fontSize="sm"
|
||||
p={0}
|
||||
w="full"
|
||||
noOfLines={1}
|
||||
color={batchGroupColorToken}
|
||||
fontWeight={batchGroupId ? 'semibold' : undefined}
|
||||
>
|
||||
{localTitleWithBatchGroupId}
|
||||
</Preview>
|
||||
<EditableInput className="nodrag" fontSize="sm" sx={editableInputStyles} />
|
||||
<EditableControls />
|
||||
</Editable>
|
||||
@@ -60,6 +83,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
|
||||
export default memo(NodeTitle);
|
||||
|
||||
const Preview = (props: TextProps) => {
|
||||
const { isEditing } = useEditableControls();
|
||||
|
||||
if (isEditing) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <Text {...props} />;
|
||||
};
|
||||
|
||||
function EditableControls() {
|
||||
const { isEditing, getEditButtonProps } = useEditableControls();
|
||||
const handleDoubleClick = useCallback(
|
||||
|
||||
@@ -5,7 +5,7 @@ import NodeOpacitySlider from './NodeOpacitySlider';
|
||||
import ViewportControls from './ViewportControls';
|
||||
|
||||
const BottomLeftPanel = () => (
|
||||
<Flex gap={2} position="absolute" bottom={0} insetInlineStart={0}>
|
||||
<Flex gap={2} position="absolute" bottom={2} insetInlineStart={2}>
|
||||
<ViewportControls />
|
||||
<NodeOpacitySlider />
|
||||
</Flex>
|
||||
|
||||
@@ -20,7 +20,7 @@ const MinimapPanel = () => {
|
||||
const shouldShowMinimapPanel = useAppSelector(selectShouldShowMinimapPanel);
|
||||
|
||||
return (
|
||||
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
|
||||
<Flex gap={2} position="absolute" bottom={2} insetInlineEnd={2}>
|
||||
{shouldShowMinimapPanel && (
|
||||
<ChakraMiniMap
|
||||
pannable
|
||||
|
||||
@@ -12,7 +12,7 @@ import { memo } from 'react';
|
||||
const TopCenterPanel = () => {
|
||||
const name = useAppSelector(selectWorkflowName);
|
||||
return (
|
||||
<Flex gap={2} top={0} left={0} right={0} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
<Flex gap="2">
|
||||
<AddNodeButton />
|
||||
<UpdateNodesButton />
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useBatchGroupColorToken = (batchGroupId?: string) => {
|
||||
const batchGroupColorToken = useMemo(() => {
|
||||
switch (batchGroupId) {
|
||||
case 'Group 1':
|
||||
return 'invokeGreen.300';
|
||||
case 'Group 2':
|
||||
return 'invokeBlue.300';
|
||||
case 'Group 3':
|
||||
return 'invokePurple.200';
|
||||
case 'Group 4':
|
||||
return 'invokeRed.300';
|
||||
case 'Group 5':
|
||||
return 'invokeYellow.300';
|
||||
default:
|
||||
return undefined;
|
||||
}
|
||||
}, [batchGroupId]);
|
||||
|
||||
return batchGroupColorToken;
|
||||
};
|
||||
@@ -0,0 +1,19 @@
|
||||
import { useNode } from 'features/nodes/hooks/useNode';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useBatchGroupId = (nodeId: string) => {
|
||||
const node = useNode(nodeId);
|
||||
|
||||
const batchGroupId = useMemo(() => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
if (!isBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node.data.inputs['batch_group_id']?.value as string;
|
||||
}, [node]);
|
||||
|
||||
return batchGroupId;
|
||||
};
|
||||
@@ -3,7 +3,21 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
validateImageFieldCollectionValue,
|
||||
validateNumberFieldCollectionValue,
|
||||
validateStringFieldCollectionValue,
|
||||
} from 'features/nodes/types/fieldValidators';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
|
||||
@@ -35,13 +49,27 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
|
||||
}
|
||||
|
||||
// Else special handling for individual field types
|
||||
|
||||
if (isImageFieldCollectionInputInstance(field) && isImageFieldCollectionInputTemplate(template)) {
|
||||
// Image collections may have min or max item counts
|
||||
if (template.minItems !== undefined && field.value.length < template.minItems) {
|
||||
if (validateImageFieldCollectionValue(field.value, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (template.maxItems !== undefined && field.value.length > template.maxItems) {
|
||||
if (isStringFieldCollectionInputInstance(field) && isStringFieldCollectionInputTemplate(template)) {
|
||||
if (validateStringFieldCollectionValue(field.value, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
|
||||
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
|
||||
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,10 +16,13 @@ import type {
|
||||
EnumFieldValue,
|
||||
FieldValue,
|
||||
FloatFieldValue,
|
||||
FloatGeneratorFieldValue,
|
||||
FluxVAEModelFieldValue,
|
||||
ImageFieldCollectionValue,
|
||||
ImageFieldValue,
|
||||
IntegerFieldCollectionValue,
|
||||
IntegerFieldValue,
|
||||
IntegerGeneratorFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
LoRAModelFieldValue,
|
||||
MainModelFieldValue,
|
||||
@@ -28,7 +31,9 @@ import type {
|
||||
SDXLRefinerModelFieldValue,
|
||||
SpandrelImageToImageModelFieldValue,
|
||||
StatefulFieldValue,
|
||||
StringFieldCollectionValue,
|
||||
StringFieldValue,
|
||||
StringGeneratorFieldValue,
|
||||
T2IAdapterModelFieldValue,
|
||||
T5EncoderModelFieldValue,
|
||||
VAEModelFieldValue,
|
||||
@@ -43,11 +48,15 @@ import {
|
||||
zControlLoRAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zEnumFieldValue,
|
||||
zFloatFieldCollectionValue,
|
||||
zFloatFieldValue,
|
||||
zFloatGeneratorFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zImageFieldCollectionValue,
|
||||
zImageFieldValue,
|
||||
zIntegerFieldCollectionValue,
|
||||
zIntegerFieldValue,
|
||||
zIntegerGeneratorFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zMainModelFieldValue,
|
||||
@@ -56,7 +65,9 @@ import {
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zStatefulFieldValue,
|
||||
zStringFieldCollectionValue,
|
||||
zStringFieldValue,
|
||||
zStringGeneratorFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zT5EncoderModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
@@ -311,9 +322,15 @@ export const nodesSlice = createSlice({
|
||||
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStringFieldValue);
|
||||
},
|
||||
fieldStringCollectionValueChanged: (state, action: FieldValueAction<StringFieldCollectionValue>) => {
|
||||
fieldValueReducer(state, action, zStringFieldCollectionValue);
|
||||
},
|
||||
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
|
||||
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
|
||||
},
|
||||
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
|
||||
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
|
||||
},
|
||||
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
|
||||
fieldValueReducer(state, action, zBooleanFieldValue);
|
||||
},
|
||||
@@ -383,6 +400,15 @@ export const nodesSlice = createSlice({
|
||||
fieldSchedulerValueChanged: (state, action: FieldValueAction<SchedulerFieldValue>) => {
|
||||
fieldValueReducer(state, action, zSchedulerFieldValue);
|
||||
},
|
||||
fieldFloatGeneratorValueChanged: (state, action: FieldValueAction<FloatGeneratorFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFloatGeneratorFieldValue);
|
||||
},
|
||||
fieldIntegerGeneratorValueChanged: (state, action: FieldValueAction<IntegerGeneratorFieldValue>) => {
|
||||
fieldValueReducer(state, action, zIntegerGeneratorFieldValue);
|
||||
},
|
||||
fieldStringGeneratorValueChanged: (state, action: FieldValueAction<StringGeneratorFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStringGeneratorFieldValue);
|
||||
},
|
||||
notesNodeValueChanged: (state, action: PayloadAction<{ nodeId: string; value: string }>) => {
|
||||
const { nodeId, value } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
@@ -435,9 +461,11 @@ export const {
|
||||
fieldModelIdentifierValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
fieldNumberCollectionValueChanged,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldStringCollectionValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
@@ -445,6 +473,9 @@ export const {
|
||||
fieldCLIPGEmbedValueChanged,
|
||||
fieldControlLoRAModelValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
fieldFloatGeneratorValueChanged,
|
||||
fieldIntegerGeneratorValueChanged,
|
||||
fieldStringGeneratorValueChanged,
|
||||
nodeEditorReset,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
@@ -546,9 +577,11 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
fieldNumberCollectionValueChanged,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldStringCollectionValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
|
||||
@@ -8,17 +8,21 @@ describe(areTypesEqual.name, () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
name: 'Foo',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
name: 'Bar',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
};
|
||||
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||
@@ -28,17 +32,21 @@ describe(areTypesEqual.name, () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
name: 'Foo',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
};
|
||||
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||
@@ -48,17 +56,21 @@ describe(areTypesEqual.name, () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
name: 'Bar',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
};
|
||||
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||
@@ -68,17 +80,21 @@ describe(areTypesEqual.name, () => {
|
||||
const sourceType: FieldType = {
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
};
|
||||
const targetType: FieldType = {
|
||||
name: 'LoRAModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
originalType: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
};
|
||||
expect(areTypesEqual(sourceType, targetType)).toBe(true);
|
||||
|
||||
@@ -11,7 +11,7 @@ describe(getCollectItemType.name, () => {
|
||||
const n2 = buildNode(collect);
|
||||
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
|
||||
const result = getCollectItemType(templates, [n1, n2], [e1], n2.id);
|
||||
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE' });
|
||||
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE', batch: false });
|
||||
});
|
||||
it('should return null if the collect node does not have any connections', () => {
|
||||
const n1 = buildNode(collect);
|
||||
|
||||
@@ -34,6 +34,7 @@ export const add: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
default: 0,
|
||||
},
|
||||
@@ -48,6 +49,7 @@ export const add: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
default: 0,
|
||||
},
|
||||
@@ -61,6 +63,7 @@ export const add: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
@@ -89,6 +92,7 @@ export const sub: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
default: 0,
|
||||
},
|
||||
@@ -103,6 +107,7 @@ export const sub: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
default: 0,
|
||||
},
|
||||
@@ -116,6 +121,7 @@ export const sub: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
@@ -145,6 +151,7 @@ export const collect: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'CollectionItemField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -157,6 +164,7 @@ export const collect: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'CollectionField',
|
||||
cardinality: 'COLLECTION',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
ui_type: 'CollectionField',
|
||||
@@ -187,10 +195,11 @@ const scheduler: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'SchedulerField',
|
||||
cardinality: 'SINGLE',
|
||||
|
||||
batch: false,
|
||||
originalType: {
|
||||
name: 'EnumField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
},
|
||||
default: 'euler',
|
||||
@@ -205,10 +214,12 @@ const scheduler: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'SchedulerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
|
||||
originalType: {
|
||||
name: 'EnumField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
},
|
||||
ui_hidden: false,
|
||||
@@ -240,10 +251,12 @@ export const main_model_loader: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'MainModelField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
|
||||
originalType: {
|
||||
name: 'ModelIdentifierField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -257,6 +270,7 @@ export const main_model_loader: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'VAEField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
@@ -268,6 +282,7 @@ export const main_model_loader: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'CLIPField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
@@ -279,6 +294,7 @@ export const main_model_loader: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'UNetField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
@@ -307,6 +323,7 @@ export const img_resize: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'BoardField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
},
|
||||
metadata: {
|
||||
@@ -320,6 +337,7 @@ export const img_resize: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'MetadataField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
},
|
||||
image: {
|
||||
@@ -333,6 +351,7 @@ export const img_resize: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'ImageField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
},
|
||||
width: {
|
||||
@@ -346,6 +365,7 @@ export const img_resize: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
default: 512,
|
||||
exclusiveMinimum: 0,
|
||||
@@ -361,6 +381,7 @@ export const img_resize: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
default: 512,
|
||||
exclusiveMinimum: 0,
|
||||
@@ -376,6 +397,7 @@ export const img_resize: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'EnumField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'],
|
||||
default: 'bicubic',
|
||||
@@ -390,6 +412,7 @@ export const img_resize: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'ImageField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
@@ -401,6 +424,7 @@ export const img_resize: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
@@ -412,6 +436,7 @@ export const img_resize: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
@@ -441,6 +466,7 @@ const iterate: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'CollectionField',
|
||||
cardinality: 'COLLECTION',
|
||||
batch: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -453,6 +479,7 @@ const iterate: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'CollectionItemField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
ui_type: 'CollectionItemField',
|
||||
@@ -465,6 +492,7 @@ const iterate: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
@@ -476,6 +504,7 @@ const iterate: InvocationTemplate = {
|
||||
type: {
|
||||
name: 'IntegerField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
ui_hidden: false,
|
||||
},
|
||||
|
||||
@@ -6,50 +6,57 @@ describe(validateConnectionTypes.name, () => {
|
||||
describe('generic cases', () => {
|
||||
it('should accept SINGLE to SINGLE of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE' }
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept COLLECTION to COLLECTION of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'COLLECTION' },
|
||||
{ name: 'FooField', cardinality: 'COLLECTION' }
|
||||
{ name: 'FooField', cardinality: 'COLLECTION', batch: false },
|
||||
{ name: 'FooField', cardinality: 'COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept SINGLE to SINGLE_OR_COLLECTION of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept COLLECTION to SINGLE_OR_COLLECTION of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'COLLECTION' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: 'FooField', cardinality: 'COLLECTION', batch: false },
|
||||
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should reject COLLECTION to SINGLE of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'COLLECTION' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE' }
|
||||
{ name: 'FooField', cardinality: 'COLLECTION', batch: false },
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
it('should reject SINGLE_OR_COLLECTION to SINGLE of same type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
{ name: 'FooField', cardinality: 'SINGLE' }
|
||||
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
it('should reject types with mismatch batch fields', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: true }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
it('should reject mismatched types', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'BarField', cardinality: 'SINGLE' }
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'BarField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
@@ -58,16 +65,16 @@ describe(validateConnectionTypes.name, () => {
|
||||
describe('special cases', () => {
|
||||
it('should reject a COLLECTION input to a COLLECTION input', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionField', cardinality: 'COLLECTION' },
|
||||
{ name: 'CollectionField', cardinality: 'COLLECTION' }
|
||||
{ name: 'CollectionField', cardinality: 'COLLECTION', batch: false },
|
||||
{ name: 'CollectionField', cardinality: 'COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
|
||||
it('should accept equal types', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' }
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
@@ -75,36 +82,36 @@ describe(validateConnectionTypes.name, () => {
|
||||
describe('CollectionItemField', () => {
|
||||
it('should accept CollectionItemField to any SINGLE target', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' }
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept CollectionItemField to any SINGLE_OR_COLLECTION target', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any SINGLE to CollectionItemField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' },
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should reject any COLLECTION to CollectionItemField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'COLLECTION' },
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
|
||||
{ name: 'IntegerField', cardinality: 'COLLECTION', batch: false },
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
it('should reject any SINGLE_OR_COLLECTION to CollectionItemField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(false);
|
||||
});
|
||||
@@ -113,22 +120,22 @@ describe(validateConnectionTypes.name, () => {
|
||||
describe('SINGLE_OR_COLLECTION', () => {
|
||||
it('should accept any SINGLE of same type to SINGLE_OR_COLLECTION', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'COLLECTION' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: 'IntegerField', cardinality: 'COLLECTION', batch: false },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any SINGLE_OR_COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
@@ -137,15 +144,15 @@ describe(validateConnectionTypes.name, () => {
|
||||
describe('CollectionField', () => {
|
||||
it('should accept any CollectionField to any COLLECTION type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'COLLECTION' }
|
||||
{ name: 'CollectionField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'IntegerField', cardinality: 'COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any CollectionField to any SINGLE_OR_COLLECTION type', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'CollectionField', cardinality: 'SINGLE' },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: 'CollectionField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
@@ -159,27 +166,30 @@ describe(validateConnectionTypes.name, () => {
|
||||
{ t1: 'FloatField', t2: 'StringField' },
|
||||
];
|
||||
it.each(typePairs)('should accept SINGLE $t1 to SINGLE $t2', ({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes({ name: t1, cardinality: 'SINGLE' }, { name: t2, cardinality: 'SINGLE' });
|
||||
const r = validateConnectionTypes(
|
||||
{ name: t1, cardinality: 'SINGLE', batch: false },
|
||||
{ name: t2, cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it.each(typePairs)('should accept SINGLE $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: t1, cardinality: 'SINGLE' },
|
||||
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: t1, cardinality: 'SINGLE', batch: false },
|
||||
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it.each(typePairs)('should accept COLLECTION $t1 to COLLECTION $t2', ({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: t1, cardinality: 'COLLECTION' },
|
||||
{ name: t2, cardinality: 'COLLECTION' }
|
||||
{ name: t1, cardinality: 'COLLECTION', batch: false },
|
||||
{ name: t2, cardinality: 'COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it.each(typePairs)('should accept COLLECTION $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: t1, cardinality: 'COLLECTION' },
|
||||
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: t1, cardinality: 'COLLECTION', batch: false },
|
||||
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
@@ -187,8 +197,8 @@ describe(validateConnectionTypes.name, () => {
|
||||
'should accept SINGLE_OR_COLLECTION $t1 to SINGLE_OR_COLLECTION $t2',
|
||||
({ t1, t2 }: TypePair) => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: t1, cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: t1, cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
}
|
||||
@@ -198,22 +208,22 @@ describe(validateConnectionTypes.name, () => {
|
||||
describe('AnyField', () => {
|
||||
it('should accept any SINGLE type to AnyField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'AnyField', cardinality: 'SINGLE' }
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'AnyField', cardinality: 'SINGLE', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any COLLECTION type to AnyField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'AnyField', cardinality: 'COLLECTION' }
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'AnyField', cardinality: 'COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
it('should accept any SINGLE_OR_COLLECTION type to AnyField', () => {
|
||||
const r = validateConnectionTypes(
|
||||
{ name: 'FooField', cardinality: 'SINGLE' },
|
||||
{ name: 'AnyField', cardinality: 'SINGLE_OR_COLLECTION' }
|
||||
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
|
||||
{ name: 'AnyField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
|
||||
);
|
||||
expect(r).toBe(true);
|
||||
});
|
||||
|
||||
@@ -19,6 +19,11 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
|
||||
return true;
|
||||
}
|
||||
|
||||
// Batch and non-batch fields are incompatible.
|
||||
if (sourceType.batch !== targetType.batch) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection types must be the same for a connection, with exceptions:
|
||||
* - CollectionItem can connect to any non-COLLECTION (e.g. SINGLE or SINGLE_OR_COLLECTION)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,110 @@
|
||||
import type {
|
||||
FloatFieldCollectionInputTemplate,
|
||||
FloatFieldCollectionValue,
|
||||
ImageFieldCollectionInputTemplate,
|
||||
ImageFieldCollectionValue,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
IntegerFieldCollectionValue,
|
||||
StringFieldCollectionInputTemplate,
|
||||
StringFieldCollectionValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const validateImageFieldCollectionValue = (
|
||||
value: NonNullable<ImageFieldCollectionValue>,
|
||||
template: ImageFieldCollectionInputTemplate
|
||||
): string[] => {
|
||||
const reasons: string[] = [];
|
||||
const { minItems, maxItems } = template;
|
||||
const count = value.length;
|
||||
|
||||
// Image collections may have min or max items to validate
|
||||
if (minItems !== undefined && minItems > 0 && count === 0) {
|
||||
reasons.push(t('parameters.invoke.collectionEmpty'));
|
||||
}
|
||||
|
||||
if (minItems !== undefined && count < minItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
|
||||
}
|
||||
|
||||
if (maxItems !== undefined && count > maxItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
|
||||
}
|
||||
|
||||
return reasons;
|
||||
};
|
||||
|
||||
export const validateStringFieldCollectionValue = (
|
||||
value: NonNullable<StringFieldCollectionValue>,
|
||||
template: StringFieldCollectionInputTemplate
|
||||
): string[] => {
|
||||
const reasons: string[] = [];
|
||||
const { minItems, maxItems, minLength, maxLength } = template;
|
||||
const count = value.length;
|
||||
|
||||
// Image collections may have min or max items to validate
|
||||
if (minItems !== undefined && minItems > 0 && count === 0) {
|
||||
reasons.push(t('parameters.invoke.collectionEmpty'));
|
||||
}
|
||||
|
||||
if (minItems !== undefined && count < minItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
|
||||
}
|
||||
|
||||
if (maxItems !== undefined && count > maxItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
|
||||
}
|
||||
|
||||
for (const str of value) {
|
||||
if (maxLength !== undefined && str.length > maxLength) {
|
||||
reasons.push(t('parameters.invoke.collectionStringTooLong', { value, maxLength }));
|
||||
}
|
||||
if (minLength !== undefined && str.length < minLength) {
|
||||
reasons.push(t('parameters.invoke.collectionStringTooShort', { value, minLength }));
|
||||
}
|
||||
}
|
||||
|
||||
return reasons;
|
||||
};
|
||||
|
||||
export const validateNumberFieldCollectionValue = (
|
||||
value: NonNullable<IntegerFieldCollectionValue> | NonNullable<FloatFieldCollectionValue>,
|
||||
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
|
||||
): string[] => {
|
||||
const reasons: string[] = [];
|
||||
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum, multipleOf } = template;
|
||||
const count = value.length;
|
||||
|
||||
// Image collections may have min or max items to validate
|
||||
if (minItems !== undefined && minItems > 0 && count === 0) {
|
||||
reasons.push(t('parameters.invoke.collectionEmpty'));
|
||||
}
|
||||
|
||||
if (minItems !== undefined && count < minItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
|
||||
}
|
||||
|
||||
if (maxItems !== undefined && count > maxItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
|
||||
}
|
||||
|
||||
for (const num of value) {
|
||||
if (maximum !== undefined && num > maximum) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberGTMax', { value, maximum }));
|
||||
}
|
||||
if (minimum !== undefined && num < minimum) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberLTMin', { value, minimum }));
|
||||
}
|
||||
if (exclusiveMaximum !== undefined && num >= exclusiveMaximum) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberGTExclusiveMax', { value, exclusiveMaximum }));
|
||||
}
|
||||
if (exclusiveMinimum !== undefined && num <= exclusiveMinimum) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberLTExclusiveMin', { value, exclusiveMinimum }));
|
||||
}
|
||||
if (multipleOf !== undefined && num % multipleOf !== 0) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberNotMultipleOf', { value, multipleOf }));
|
||||
}
|
||||
}
|
||||
|
||||
return reasons;
|
||||
};
|
||||
@@ -91,3 +91,30 @@ const zInvocationNodeEdgeExtra = z.object({
|
||||
type InvocationNodeEdgeExtra = z.infer<typeof zInvocationNodeEdgeExtra>;
|
||||
export type InvocationNodeEdge = Edge<InvocationNodeEdgeExtra>;
|
||||
// #endregion
|
||||
|
||||
export const isBatchNode = (node: InvocationNode) => {
|
||||
switch (node.data.type) {
|
||||
case 'image_batch':
|
||||
case 'string_batch':
|
||||
case 'integer_batch':
|
||||
case 'float_batch':
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const isGeneratorNode = (node: InvocationNode) => {
|
||||
switch (node.data.type) {
|
||||
case 'float_generator':
|
||||
case 'integer_generator':
|
||||
case 'string_generator':
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
export const isExecutableNode = (node: InvocationNode) => {
|
||||
return !isBatchNode(node) && !isGeneratorNode(node);
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { omit, reduce } from 'lodash-es';
|
||||
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
@@ -14,7 +14,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
const { nodes, edges } = nodesState;
|
||||
|
||||
// Exclude all batch nodes - we will handle these in the batch setup in a diff function
|
||||
const filteredNodes = nodes.filter(isInvocationNode).filter((node) => node.data.type !== 'image_batch');
|
||||
const filteredNodes = nodes.filter(isInvocationNode).filter(isExecutableNode);
|
||||
|
||||
// Reduce the node editor nodes into invocation graph nodes
|
||||
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
|
||||
|
||||
@@ -29,6 +29,9 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
CLIPLEmbedModelField: undefined,
|
||||
CLIPGEmbedModelField: undefined,
|
||||
ControlLoRAModelField: undefined,
|
||||
FloatGeneratorField: undefined,
|
||||
IntegerGeneratorField: undefined,
|
||||
StringGeneratorField: undefined,
|
||||
};
|
||||
|
||||
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
|
||||
|
||||
@@ -11,12 +11,16 @@ import type {
|
||||
EnumFieldInputTemplate,
|
||||
FieldInputTemplate,
|
||||
FieldType,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
FloatFieldInputTemplate,
|
||||
FloatGeneratorFieldInputTemplate,
|
||||
FluxMainModelFieldInputTemplate,
|
||||
FluxVAEModelFieldInputTemplate,
|
||||
ImageFieldCollectionInputTemplate,
|
||||
ImageFieldInputTemplate,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
IntegerGeneratorFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
LoRAModelFieldInputTemplate,
|
||||
MainModelFieldInputTemplate,
|
||||
@@ -28,12 +32,23 @@ import type {
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
StatefulFieldType,
|
||||
StatelessFieldInputTemplate,
|
||||
StringFieldCollectionInputTemplate,
|
||||
StringFieldInputTemplate,
|
||||
StringGeneratorFieldInputTemplate,
|
||||
T2IAdapterModelFieldInputTemplate,
|
||||
T5EncoderModelFieldInputTemplate,
|
||||
VAEModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isImageCollectionFieldType, isStatefulFieldType } from 'features/nodes/types/field';
|
||||
import {
|
||||
getFloatGeneratorArithmeticSequenceDefaults,
|
||||
getIntegerGeneratorArithmeticSequenceDefaults,
|
||||
getStringGeneratorParseStringDefaults,
|
||||
isFloatCollectionFieldType,
|
||||
isImageCollectionFieldType,
|
||||
isIntegerCollectionFieldType,
|
||||
isStatefulFieldType,
|
||||
isStringCollectionFieldType,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { InvocationFieldSchema } from 'features/nodes/types/openapi';
|
||||
import { isSchemaObject } from 'features/nodes/types/openapi';
|
||||
import { t } from 'i18next';
|
||||
@@ -77,6 +92,48 @@ const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<IntegerFieldInpu
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerFieldCollectionInputTemplate: FieldInputTemplateBuilder<IntegerFieldCollectionInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: IntegerFieldCollectionInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
|
||||
};
|
||||
|
||||
if (schemaObject.minItems !== undefined) {
|
||||
template.minItems = schemaObject.minItems;
|
||||
}
|
||||
|
||||
if (schemaObject.maxItems !== undefined) {
|
||||
template.maxItems = schemaObject.maxItems;
|
||||
}
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -111,6 +168,48 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTem
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatFieldCollectionInputTemplate: FieldInputTemplateBuilder<FloatFieldCollectionInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FloatFieldCollectionInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
|
||||
};
|
||||
|
||||
if (schemaObject.minItems !== undefined) {
|
||||
template.minItems = schemaObject.minItems;
|
||||
}
|
||||
|
||||
if (schemaObject.maxItems !== undefined) {
|
||||
template.maxItems = schemaObject.maxItems;
|
||||
}
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -133,6 +232,36 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputT
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringFieldCollectionInputTemplate: FieldInputTemplateBuilder<StringFieldCollectionInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: StringFieldCollectionInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
|
||||
};
|
||||
|
||||
if (schemaObject.minLength !== undefined) {
|
||||
template.minLength = schemaObject.minLength;
|
||||
}
|
||||
|
||||
if (schemaObject.maxLength !== undefined) {
|
||||
template.maxLength = schemaObject.maxLength;
|
||||
}
|
||||
|
||||
if (schemaObject.minItems !== undefined) {
|
||||
template.minItems = schemaObject.minItems;
|
||||
}
|
||||
|
||||
if (schemaObject.maxItems !== undefined) {
|
||||
template.maxItems = schemaObject.maxItems;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -514,6 +643,48 @@ const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder<SchedulerField
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatGeneratorFieldInputTemplate: FieldInputTemplateBuilder<FloatGeneratorFieldInputTemplate> = ({
|
||||
// schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FloatGeneratorFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: getFloatGeneratorArithmeticSequenceDefaults(),
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerGeneratorFieldInputTemplate: FieldInputTemplateBuilder<IntegerGeneratorFieldInputTemplate> = ({
|
||||
// schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: IntegerGeneratorFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: getIntegerGeneratorArithmeticSequenceDefaults(),
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringGeneratorFieldInputTemplate: FieldInputTemplateBuilder<StringGeneratorFieldInputTemplate> = ({
|
||||
// schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: StringGeneratorFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: getStringGeneratorParseStringDefaults(),
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
|
||||
BoardField: buildBoardFieldInputTemplate,
|
||||
BooleanField: buildBooleanFieldInputTemplate,
|
||||
@@ -542,6 +713,9 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
|
||||
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
|
||||
ControlLoRAModelField: buildControlLoRAModelFieldInputTemplate,
|
||||
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
|
||||
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
|
||||
StringGeneratorField: buildStringGeneratorFieldInputTemplate,
|
||||
} as const;
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
@@ -569,12 +743,29 @@ export const buildFieldInputTemplate = (
|
||||
|
||||
if (isStatefulFieldType(fieldType)) {
|
||||
if (isImageCollectionFieldType(fieldType)) {
|
||||
fieldType;
|
||||
return buildImageFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else if (isStringCollectionFieldType(fieldType)) {
|
||||
return buildStringFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else if (isIntegerCollectionFieldType(fieldType)) {
|
||||
return buildIntegerFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else if (isFloatCollectionFieldType(fieldType)) {
|
||||
return buildFloatFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else {
|
||||
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
|
||||
const template = builder({
|
||||
|
||||
@@ -19,42 +19,42 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
|
||||
{
|
||||
name: 'SINGLE IntegerField',
|
||||
schema: { type: 'integer' },
|
||||
expected: { name: 'IntegerField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'IntegerField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'SINGLE FloatField',
|
||||
schema: { type: 'number' },
|
||||
expected: { name: 'FloatField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'FloatField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'SINGLE StringField',
|
||||
schema: { type: 'string' },
|
||||
expected: { name: 'StringField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'StringField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'SINGLE BooleanField',
|
||||
schema: { type: 'boolean' },
|
||||
expected: { name: 'BooleanField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'BooleanField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'COLLECTION IntegerField',
|
||||
schema: { items: { type: 'integer' }, type: 'array' },
|
||||
expected: { name: 'IntegerField', cardinality: 'COLLECTION' },
|
||||
expected: { name: 'IntegerField', cardinality: 'COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'COLLECTION FloatField',
|
||||
schema: { items: { type: 'number' }, type: 'array' },
|
||||
expected: { name: 'FloatField', cardinality: 'COLLECTION' },
|
||||
expected: { name: 'FloatField', cardinality: 'COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'COLLECTION StringField',
|
||||
schema: { items: { type: 'string' }, type: 'array' },
|
||||
expected: { name: 'StringField', cardinality: 'COLLECTION' },
|
||||
expected: { name: 'StringField', cardinality: 'COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'COLLECTION BooleanField',
|
||||
schema: { items: { type: 'boolean' }, type: 'array' },
|
||||
expected: { name: 'BooleanField', cardinality: 'COLLECTION' },
|
||||
expected: { name: 'BooleanField', cardinality: 'COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'SINGLE_OR_COLLECTION IntegerField',
|
||||
@@ -71,7 +71,7 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
expected: { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'SINGLE_OR_COLLECTION FloatField',
|
||||
@@ -88,7 +88,7 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'FloatField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
expected: { name: 'FloatField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'SINGLE_OR_COLLECTION StringField',
|
||||
@@ -105,7 +105,7 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'StringField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
expected: { name: 'StringField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'SINGLE_OR_COLLECTION BooleanField',
|
||||
@@ -122,7 +122,7 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'BooleanField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
expected: { name: 'BooleanField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
},
|
||||
];
|
||||
|
||||
@@ -136,7 +136,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'ConditioningField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'Nullable SINGLE ConditioningField',
|
||||
@@ -150,7 +150,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'ConditioningField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'COLLECTION ConditioningField',
|
||||
@@ -164,7 +164,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', cardinality: 'COLLECTION' },
|
||||
expected: { name: 'ConditioningField', cardinality: 'COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'Nullable Collection ConditioningField',
|
||||
@@ -181,7 +181,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', cardinality: 'COLLECTION' },
|
||||
expected: { name: 'ConditioningField', cardinality: 'COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'SINGLE_OR_COLLECTION ConditioningField',
|
||||
@@ -198,7 +198,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'Nullable SINGLE_OR_COLLECTION ConditioningField',
|
||||
@@ -218,7 +218,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
|
||||
},
|
||||
],
|
||||
},
|
||||
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' },
|
||||
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
|
||||
},
|
||||
];
|
||||
|
||||
@@ -229,14 +229,14 @@ const specialCases: ParseFieldTypeTestCase[] = [
|
||||
type: 'string',
|
||||
enum: ['large', 'base', 'small'],
|
||||
},
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'String EnumField with one value',
|
||||
schema: {
|
||||
const: 'Some Value',
|
||||
},
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'Explicit ui_type (SchedulerField)',
|
||||
@@ -245,7 +245,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
|
||||
enum: ['ddim', 'ddpm', 'deis'],
|
||||
ui_type: 'SchedulerField',
|
||||
},
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'Explicit ui_type (AnyField)',
|
||||
@@ -254,7 +254,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
|
||||
enum: ['ddim', 'ddpm', 'deis'],
|
||||
ui_type: 'AnyField',
|
||||
},
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
{
|
||||
name: 'Explicit ui_type (CollectionField)',
|
||||
@@ -263,7 +263,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
|
||||
enum: ['ddim', 'ddpm', 'deis'],
|
||||
ui_type: 'CollectionField',
|
||||
},
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE' },
|
||||
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
];
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
return {
|
||||
name: 'EnumField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
};
|
||||
}
|
||||
if (!schemaObject.type) {
|
||||
@@ -65,6 +66,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
return {
|
||||
name,
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
};
|
||||
}
|
||||
} else if (schemaObject.anyOf) {
|
||||
@@ -88,6 +90,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
return {
|
||||
name,
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
};
|
||||
} else if (isSchemaObject(filteredAnyOf[0])) {
|
||||
return parseFieldType(filteredAnyOf[0]);
|
||||
@@ -141,6 +144,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
return {
|
||||
name: OPENAPI_TO_FIELD_TYPE_MAP[firstType] ?? firstType,
|
||||
cardinality: 'SINGLE_OR_COLLECTION',
|
||||
batch: false,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -155,6 +159,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
return {
|
||||
name: 'EnumField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
};
|
||||
} else if (schemaObject.type) {
|
||||
if (schemaObject.type === 'array') {
|
||||
@@ -181,6 +186,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
return {
|
||||
name,
|
||||
cardinality: 'COLLECTION',
|
||||
batch: false,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -192,6 +198,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
return {
|
||||
name,
|
||||
cardinality: 'COLLECTION',
|
||||
batch: false,
|
||||
};
|
||||
} else if (!isArray(schemaObject.type)) {
|
||||
// This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean'
|
||||
@@ -207,6 +214,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
return {
|
||||
name,
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -218,6 +226,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
|
||||
return {
|
||||
name,
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
};
|
||||
}
|
||||
throw new FieldParseError(t('nodes.unableToParseFieldType'));
|
||||
|
||||
@@ -107,6 +107,7 @@ export const parseSchema = (
|
||||
? {
|
||||
name: property.ui_type,
|
||||
cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE',
|
||||
batch: false,
|
||||
}
|
||||
: null;
|
||||
|
||||
@@ -127,6 +128,12 @@ export const parseSchema = (
|
||||
fieldType.originalType = deepClone(originalFieldType);
|
||||
}
|
||||
|
||||
if (type === 'float_batch' && propertyName === 'floats') {
|
||||
fieldType.batch = true;
|
||||
} else if (type === 'integer_batch' && propertyName === 'integers') {
|
||||
fieldType.batch = true;
|
||||
}
|
||||
|
||||
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
|
||||
inputsAccumulator[propertyName] = fieldInputTemplate;
|
||||
|
||||
@@ -172,6 +179,7 @@ export const parseSchema = (
|
||||
? {
|
||||
name: property.ui_type,
|
||||
cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE',
|
||||
batch: false,
|
||||
}
|
||||
: null;
|
||||
|
||||
@@ -187,6 +195,14 @@ export const parseSchema = (
|
||||
fieldType.originalType = deepClone(originalFieldType);
|
||||
}
|
||||
|
||||
if (type === 'float_generator' && propertyName === 'floats') {
|
||||
fieldType.batch = true;
|
||||
} else if (type === 'integer_generator' && propertyName === 'integers') {
|
||||
fieldType.batch = true;
|
||||
} else if (type === 'string_generator' && propertyName === 'strings') {
|
||||
fieldType.batch = true;
|
||||
}
|
||||
|
||||
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);
|
||||
|
||||
outputsAccumulator[propertyName] = fieldOutputTemplate;
|
||||
|
||||
@@ -20,7 +20,7 @@ import { z } from 'zod';
|
||||
* @param schema The zod schema to create a type guard from.
|
||||
* @returns A type guard function for the schema.
|
||||
*/
|
||||
const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
|
||||
export const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
|
||||
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
|
||||
};
|
||||
|
||||
|
||||
@@ -206,8 +206,16 @@ const QueueCountPredictionWorkflowsTab = memo(() => {
|
||||
const iterationsCount = useAppSelector(selectIterations);
|
||||
|
||||
const text = useMemo(() => {
|
||||
const generationCount = Math.min(batchSize * iterationsCount, 10000);
|
||||
const iterations = t('queue.iterations', { count: iterationsCount });
|
||||
if (batchSize === 'NO_BATCHES') {
|
||||
const generationCount = Math.min(10000, iterationsCount);
|
||||
const generations = t('queue.generations', { count: generationCount });
|
||||
return `${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase();
|
||||
}
|
||||
if (batchSize === 'EMPTY_BATCHES') {
|
||||
return t('parameters.invoke.invalidBatchConfigurationCannotCalculate');
|
||||
}
|
||||
const generationCount = Math.min(batchSize * iterationsCount, 10000);
|
||||
const generations = t('queue.generations', { count: generationCount });
|
||||
return `${batchSize} ${t('queue.batchSize')} \u00d7 ${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase();
|
||||
}, [batchSize, iterationsCount, t]);
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import type { AppConfig } from 'app/types/invokeai';
|
||||
import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
@@ -18,14 +19,36 @@ import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { NodesState, Templates } from 'features/nodes/store/types';
|
||||
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isFloatGeneratorFieldInputInstance,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerGeneratorFieldInputInstance,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
isStringGeneratorFieldInputInstance,
|
||||
resolveFloatGeneratorField,
|
||||
resolveIntegerGeneratorField,
|
||||
resolveStringGeneratorField,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
validateImageFieldCollectionValue,
|
||||
validateNumberFieldCollectionValue,
|
||||
validateStringFieldCollectionValue,
|
||||
} from 'features/nodes/types/fieldValidators';
|
||||
import type { InvocationNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import i18n from 'i18next';
|
||||
import { forEach, upperFirst } from 'lodash-es';
|
||||
import { forEach, groupBy, upperFirst } from 'lodash-es';
|
||||
import { getConnectedEdges } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/**
|
||||
* This file contains selectors and utilities for determining the app is ready to enqueue generations. The handling
|
||||
@@ -47,6 +70,67 @@ export type Reason = { prefix?: string; content: string };
|
||||
|
||||
const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') });
|
||||
|
||||
export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNode[], edges: InvocationNodeEdge[]) => {
|
||||
if (batchNode.data.type === 'image_batch') {
|
||||
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
|
||||
const ownValue = batchNode.data.inputs.images.value ?? [];
|
||||
// no generators for images yet
|
||||
return ownValue;
|
||||
} else if (batchNode.data.type === 'string_batch') {
|
||||
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
|
||||
const ownValue = batchNode.data.inputs.strings.value;
|
||||
const edgeToStrings = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'strings');
|
||||
|
||||
if (!edgeToStrings) {
|
||||
return ownValue ?? [];
|
||||
}
|
||||
|
||||
const generatorNode = nodes.find((node) => node.id === edgeToStrings.source);
|
||||
assert(generatorNode, 'Missing edge from string generator to string batch');
|
||||
|
||||
const generatorField = generatorNode.data.inputs['generator'];
|
||||
assert(isStringGeneratorFieldInputInstance(generatorField), 'Invalid string generator');
|
||||
|
||||
const generatorValue = resolveStringGeneratorField(generatorField);
|
||||
return generatorValue;
|
||||
} else if (batchNode.data.type === 'float_batch') {
|
||||
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
|
||||
const ownValue = batchNode.data.inputs.floats.value;
|
||||
const edgeToFloats = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'floats');
|
||||
|
||||
if (!edgeToFloats) {
|
||||
return ownValue ?? [];
|
||||
}
|
||||
|
||||
const generatorNode = nodes.find((node) => node.id === edgeToFloats.source);
|
||||
assert(generatorNode, 'Missing edge from float generator to float batch');
|
||||
|
||||
const generatorField = generatorNode.data.inputs['generator'];
|
||||
assert(isFloatGeneratorFieldInputInstance(generatorField), 'Invalid float generator');
|
||||
|
||||
const generatorValue = resolveFloatGeneratorField(generatorField);
|
||||
return generatorValue;
|
||||
} else if (batchNode.data.type === 'integer_batch') {
|
||||
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
|
||||
const ownValue = batchNode.data.inputs.integers.value;
|
||||
const incomers = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'integers');
|
||||
|
||||
if (!incomers) {
|
||||
return ownValue ?? [];
|
||||
}
|
||||
|
||||
const generatorNode = nodes.find((node) => node.id === incomers.source);
|
||||
assert(generatorNode, 'Missing edge from integer generator to integer batch');
|
||||
|
||||
const generatorField = generatorNode.data.inputs['generator'];
|
||||
assert(isIntegerGeneratorFieldInputInstance(generatorField), 'Invalid integer generator field');
|
||||
|
||||
const generatorValue = resolveIntegerGeneratorField(generatorField);
|
||||
return generatorValue;
|
||||
}
|
||||
assert(false, 'Invalid batch node type');
|
||||
};
|
||||
|
||||
const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
|
||||
isConnected: boolean;
|
||||
nodes: NodesState;
|
||||
@@ -61,11 +145,54 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
|
||||
}
|
||||
|
||||
if (workflowSettings.shouldValidateGraph) {
|
||||
if (!nodes.nodes.length) {
|
||||
const invocationNodes = nodes.nodes.filter(isInvocationNode);
|
||||
const batchNodes = invocationNodes.filter(isBatchNode);
|
||||
const executableNodes = invocationNodes.filter(isExecutableNode);
|
||||
|
||||
if (!executableNodes.length) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
|
||||
}
|
||||
|
||||
nodes.nodes.forEach((node) => {
|
||||
for (const node of batchNodes) {
|
||||
if (nodes.edges.find((e) => e.source === node.id) === undefined) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.batchNodeNotConnected', { label: node.data.label }) });
|
||||
}
|
||||
}
|
||||
|
||||
if (batchNodes.length > 1) {
|
||||
const batchSizes: number[] = [];
|
||||
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
|
||||
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
|
||||
// But grouped batch nodes must have the same collection size
|
||||
const groupBatchSizes: number[] = [];
|
||||
|
||||
for (const node of batchNodes) {
|
||||
const size = resolveBatchValue(node, invocationNodes, nodes.edges).length;
|
||||
if (batchGroupId === 'None') {
|
||||
// Ungrouped batch nodes may have differing collection sizes
|
||||
batchSizes.push(size);
|
||||
} else {
|
||||
groupBatchSizes.push(size);
|
||||
}
|
||||
}
|
||||
|
||||
if (groupBatchSizes.some((count) => count !== groupBatchSizes[0])) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.batchNodeCollectionSizeMismatch', { batchGroupId }),
|
||||
});
|
||||
}
|
||||
|
||||
if (groupBatchSizes[0] !== undefined) {
|
||||
batchSizes.push(groupBatchSizes[0]);
|
||||
}
|
||||
}
|
||||
|
||||
if (batchSizes.some((size) => size === 0)) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.batchNodeEmptyCollection') });
|
||||
}
|
||||
}
|
||||
|
||||
executableNodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
@@ -91,45 +218,38 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
|
||||
return;
|
||||
}
|
||||
|
||||
const baseTKeyOptions = {
|
||||
nodeLabel: node.data.label || nodeTemplate.title,
|
||||
fieldLabel: field.label || fieldTemplate.title,
|
||||
};
|
||||
const prefix = `${node.data.label || nodeTemplate.title} -> ${field.label || fieldTemplate.title}`;
|
||||
|
||||
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.missingInputForField', baseTKeyOptions) });
|
||||
return;
|
||||
reasons.push({ prefix, content: i18n.t('parameters.invoke.missingInputForField') });
|
||||
} else if (
|
||||
field.value &&
|
||||
isImageFieldCollectionInputInstance(field) &&
|
||||
isImageFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
// Image collections may have min or max items to validate
|
||||
// TODO(psyche): generalize this to other collection types
|
||||
if (fieldTemplate.minItems !== undefined && fieldTemplate.minItems > 0 && field.value.length === 0) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.collectionEmpty', baseTKeyOptions) });
|
||||
return;
|
||||
}
|
||||
if (fieldTemplate.minItems !== undefined && field.value.length < fieldTemplate.minItems) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.collectionTooFewItems', {
|
||||
...baseTKeyOptions,
|
||||
size: field.value.length,
|
||||
minItems: fieldTemplate.minItems,
|
||||
}),
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (fieldTemplate.maxItems !== undefined && field.value.length > fieldTemplate.maxItems) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.collectionTooManyItems', {
|
||||
...baseTKeyOptions,
|
||||
size: field.value.length,
|
||||
maxItems: fieldTemplate.maxItems,
|
||||
}),
|
||||
});
|
||||
return;
|
||||
}
|
||||
const errors = validateImageFieldCollectionValue(field.value, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
} else if (
|
||||
field.value &&
|
||||
isStringFieldCollectionInputInstance(field) &&
|
||||
isStringFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateStringFieldCollectionValue(field.value, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
} else if (
|
||||
field.value &&
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
isIntegerFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
} else if (
|
||||
field.value &&
|
||||
isFloatFieldCollectionInputInstance(field) &&
|
||||
isFloatFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -491,17 +611,80 @@ export const selectPromptsCount = createSelector(
|
||||
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
|
||||
);
|
||||
|
||||
export const selectWorkflowsBatchSize = createSelector(selectNodesSlice, ({ nodes }) =>
|
||||
// The batch size is the product of all batch nodes' collection sizes
|
||||
nodes.filter(isInvocationNode).reduce((batchSize, node) => {
|
||||
if (!isImageFieldCollectionInputInstance(node.data.inputs.images)) {
|
||||
return batchSize;
|
||||
}
|
||||
// If the batch size is not set, default to 1
|
||||
batchSize = batchSize || 1;
|
||||
// Multiply the batch size by the number of images in the batch
|
||||
batchSize = batchSize * (node.data.inputs.images.value?.length ?? 0);
|
||||
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
|
||||
createMemoizedSelector(selectNodesSlice, ({ nodes, edges }) => {
|
||||
const invocationNodes = nodes.filter(isInvocationNode);
|
||||
return invocationNodes
|
||||
.filter(isBatchNode)
|
||||
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
|
||||
.map((batchNodes) => resolveBatchValue(batchNodes, invocationNodes, edges).length);
|
||||
});
|
||||
|
||||
return batchSize;
|
||||
}, 0)
|
||||
const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');
|
||||
const selectGroup1BatchSizes = buildSelectGroupBatchSizes('Group 1');
|
||||
const selectGroup2BatchSizes = buildSelectGroupBatchSizes('Group 2');
|
||||
const selectGroup3BatchSizes = buildSelectGroupBatchSizes('Group 3');
|
||||
const selectGroup4BatchSizes = buildSelectGroupBatchSizes('Group 4');
|
||||
const selectGroup5BatchSizes = buildSelectGroupBatchSizes('Group 5');
|
||||
|
||||
export const selectWorkflowsBatchSize = createSelector(
|
||||
selectUngroupedBatchSizes,
|
||||
selectGroup1BatchSizes,
|
||||
selectGroup2BatchSizes,
|
||||
selectGroup3BatchSizes,
|
||||
selectGroup4BatchSizes,
|
||||
selectGroup5BatchSizes,
|
||||
(
|
||||
ungroupedBatchSizes,
|
||||
group1BatchSizes,
|
||||
group2BatchSizes,
|
||||
group3BatchSizes,
|
||||
group4BatchSizes,
|
||||
group5BatchSizes
|
||||
): number | 'EMPTY_BATCHES' | 'NO_BATCHES' => {
|
||||
// All batch nodes _must_ have a populated collection
|
||||
|
||||
const allBatchSizes = [
|
||||
...ungroupedBatchSizes,
|
||||
...group1BatchSizes,
|
||||
...group2BatchSizes,
|
||||
...group3BatchSizes,
|
||||
...group4BatchSizes,
|
||||
...group5BatchSizes,
|
||||
];
|
||||
|
||||
// There are no batch nodes
|
||||
if (allBatchSizes.length === 0) {
|
||||
return 'NO_BATCHES';
|
||||
}
|
||||
|
||||
// All batch nodes must have a populated collection
|
||||
if (allBatchSizes.some((size) => size === 0)) {
|
||||
return 'EMPTY_BATCHES';
|
||||
}
|
||||
|
||||
for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) {
|
||||
// Ignore groups with no batch nodes
|
||||
if (group.length === 0) {
|
||||
continue;
|
||||
}
|
||||
// Grouped batch nodes must have the same collection size
|
||||
if (group.some((size) => size !== group[0])) {
|
||||
return 'EMPTY_BATCHES';
|
||||
}
|
||||
}
|
||||
|
||||
// Total batch size = product of all ungrouped batches and each grouped batch
|
||||
const totalBatchSize = [
|
||||
...ungroupedBatchSizes,
|
||||
// In case of no batch nodes in a group, fall back to 1 for the product calculation
|
||||
group1BatchSizes[0] ?? 1,
|
||||
group2BatchSizes[0] ?? 1,
|
||||
group3BatchSizes[0] ?? 1,
|
||||
group4BatchSizes[0] ?? 1,
|
||||
group5BatchSizes[0] ?? 1,
|
||||
].reduce((acc, size) => acc * size, 1);
|
||||
|
||||
return totalBatchSize;
|
||||
}
|
||||
);
|
||||
|
||||
@@ -9,6 +9,7 @@ const StylePresetImage = ({ presetImageUrl, imageWidth }: { presetImageUrl: stri
|
||||
return (
|
||||
<Tooltip
|
||||
closeOnScroll
|
||||
openDelay={0}
|
||||
label={
|
||||
presetImageUrl && (
|
||||
<Image
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { components } from 'services/api/schema';
|
||||
import type { paths } from 'services/api/schema';
|
||||
|
||||
import { api, buildV1Url } from '..';
|
||||
|
||||
@@ -13,8 +13,8 @@ const buildUtilitiesUrl = (path: string = '') => buildV1Url(`utilities/${path}`)
|
||||
export const utilitiesApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
dynamicPrompts: build.query<
|
||||
components['schemas']['DynamicPromptsResponse'],
|
||||
{ prompt: string; max_prompts: number }
|
||||
paths['/api/v1/utilities/dynamicprompts']['post']['responses']['200']['content']['application/json'],
|
||||
paths['/api/v1/utilities/dynamicprompts']['post']['requestBody']['content']['application/json']
|
||||
>({
|
||||
query: (arg) => ({
|
||||
url: buildUtilitiesUrl('dynamicprompts'),
|
||||
@@ -28,3 +28,5 @@ export const utilitiesApi = api.injectEndpoints({
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
export const { useDynamicPromptsQuery } = utilitiesApi;
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
|
||||
__version__ = "5.6.0rc2"
|
||||
__version__ = "5.6.0rc4"
|
||||
|
||||
@@ -3,7 +3,11 @@ import torch
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
|
||||
from tests.backend.model_manager.load.model_cache.cached_model.utils import (
|
||||
DummyModule,
|
||||
parameterize_keep_ram_copy,
|
||||
parameterize_mps_and_cuda,
|
||||
)
|
||||
|
||||
|
||||
class NonTorchModel:
|
||||
@@ -17,16 +21,22 @@ class NonTorchModel:
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_total_bytes(device: str, keep_ram_copy: bool):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
cached_model = CachedModelOnlyFullLoad(
|
||||
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
assert cached_model.total_bytes() == 100
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_is_in_vram(device: str):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_is_in_vram(device: str, keep_ram_copy: bool):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
cached_model = CachedModelOnlyFullLoad(
|
||||
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
assert not cached_model.is_in_vram()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
@@ -40,9 +50,12 @@ def test_cached_model_is_in_vram(device: str):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_unload(device: str):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_full_load_and_unload(device: str, keep_ram_copy: bool):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
cached_model = CachedModelOnlyFullLoad(
|
||||
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
assert cached_model.full_load_to_vram() == 100
|
||||
assert cached_model.is_in_vram()
|
||||
assert all(p.device.type == device for p in cached_model.model.parameters())
|
||||
@@ -55,7 +68,9 @@ def test_cached_model_full_load_and_unload(device: str):
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_get_cpu_state_dict(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
cached_model = CachedModelOnlyFullLoad(
|
||||
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=True
|
||||
)
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
# The CPU state dict can be accessed and has the expected properties.
|
||||
@@ -76,9 +91,12 @@ def test_cached_model_get_cpu_state_dict(device: str):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_inference(device: str):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_full_load_and_inference(device: str, keep_ram_copy: bool):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
cached_model = CachedModelOnlyFullLoad(
|
||||
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
# Run inference on the CPU.
|
||||
@@ -99,9 +117,12 @@ def test_cached_model_full_load_and_inference(device: str):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_non_torch_model(device: str):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_non_torch_model(device: str, keep_ram_copy: bool):
|
||||
model = NonTorchModel()
|
||||
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
|
||||
cached_model = CachedModelOnlyFullLoad(
|
||||
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
assert not cached_model.is_in_vram()
|
||||
|
||||
# The model does not have a CPU state dict.
|
||||
|
||||
@@ -10,7 +10,11 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
|
||||
apply_custom_layers_to_model,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
|
||||
from tests.backend.model_manager.load.model_cache.cached_model.utils import (
|
||||
DummyModule,
|
||||
parameterize_keep_ram_copy,
|
||||
parameterize_mps_and_cuda,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -21,8 +25,11 @@ def model():
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
linear1_numel = 10 * 32 + 32
|
||||
linear2_numel = 32 * 64 + 64
|
||||
buffer1_numel = 64
|
||||
@@ -31,9 +38,12 @@ def test_cached_model_total_bytes(device: str, model: DummyModule):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
@@ -45,9 +55,12 @@ def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load(device: str, model: DummyModule):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_copy: bool):
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
@@ -71,9 +84,12 @@ def test_cached_model_partial_load(device: str, model: DummyModule):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_unload(device: str, model: DummyModule):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_copy: bool):
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
@@ -99,9 +115,14 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_partial_unload_keep_required_weights_in_vram(
|
||||
device: str, model: DummyModule, keep_ram_copy: bool
|
||||
):
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
@@ -130,8 +151,11 @@ def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str,
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep_ram_copy: bool):
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
@@ -162,8 +186,11 @@ def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_full_load_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
@@ -190,8 +217,11 @@ def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_full_unload_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
@@ -219,7 +249,7 @@ def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device), keep_ram_copy=True)
|
||||
|
||||
# Model starts in CPU memory.
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
@@ -242,8 +272,11 @@ def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_full_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
@@ -269,9 +302,12 @@ def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
|
||||
@parameterize_keep_ram_copy
|
||||
def test_cached_model_partial_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
cached_model = CachedModelWithPartialLoad(
|
||||
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
|
||||
)
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
|
||||
@@ -29,3 +29,5 @@ parameterize_mps_and_cuda = pytest.mark.parametrize(
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
],
|
||||
)
|
||||
|
||||
parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
|
||||
|
||||
@@ -94,6 +94,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
|
||||
ram_cache = ModelCache(
|
||||
execution_device_working_mem_gb=mm2_app_config.device_working_mem_gb,
|
||||
enable_partial_loading=mm2_app_config.enable_partial_loading,
|
||||
keep_ram_copy_of_weights=mm2_app_config.keep_ram_copy_of_weights,
|
||||
max_ram_cache_size_gb=mm2_app_config.max_cache_ram_gb,
|
||||
max_vram_cache_size_gb=mm2_app_config.max_cache_vram_gb,
|
||||
execution_device=TorchDevice.choose_torch_device(),
|
||||
|
||||
@@ -189,6 +189,26 @@ def test_cannot_create_bad_batch_items_type(batch_graph):
|
||||
)
|
||||
|
||||
|
||||
def test_number_type_interop(batch_graph):
|
||||
# integers and floats can be mixed, should not throw an error
|
||||
Batch(
|
||||
graph=batch_graph,
|
||||
data=[
|
||||
[
|
||||
BatchDatum(node_path="1", field_name="prompt", items=[1, 1.5]),
|
||||
]
|
||||
],
|
||||
)
|
||||
Batch(
|
||||
graph=batch_graph,
|
||||
data=[
|
||||
[
|
||||
BatchDatum(node_path="1", field_name="prompt", items=[1.5, 1]),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_cannot_create_bad_batch_unique_ids(batch_graph):
|
||||
with pytest.raises(ValidationError, match="Each batch data must have unique node_id and field_name"):
|
||||
Batch(
|
||||
|
||||
Reference in New Issue
Block a user