Compare commits

..

2 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Ryan Dick | 020968a021 | Load state dict one module at a time in CachedModelWithPartialLoad. | 2025-01-14 21:36:55 +00:00 |
| Ryan Dick | bb52317377 | Add keep_ram_copy option to CachedModelWithPartialLoad. | 2025-01-14 16:09:35 +00:00 |
156 changed files with 1020 additions and 9305 deletions

.nvmrc
View File

@@ -1 +0,0 @@
v22.12.0

View File

@@ -28,12 +28,11 @@ It is possible to fine-tune the settings for best performance or if you still ge
## Details and fine-tuning
Low-VRAM mode involves 4 features, each of which can be configured or fine-tuned:
Low-VRAM mode involves 3 features, each of which can be configured or fine-tuned:
- Partial model loading (`enable_partial_loading`)
- Dynamic RAM and VRAM cache sizes (`max_cache_ram_gb`, `max_cache_vram_gb`)
- Working memory (`device_working_mem_gb`)
- Keeping a RAM weight copy (`keep_ram_copy_of_weights`)
- Partial model loading
- Dynamic RAM and VRAM cache sizes
- Working memory
Read on to learn about these features and understand how to fine-tune them for your system and use-cases.
@@ -68,20 +67,12 @@ As of v5.6.0, the caches are dynamically sized. The `ram` and `vram` settings ar
But, if your GPU has enough VRAM to hold models fully, you might get a perf boost by manually setting the cache sizes in `invokeai.yaml`:
```yaml
# The default max cache RAM size is logged on InvokeAI startup. It is determined based on your system RAM / VRAM.
# You can override the default value by setting `max_cache_ram_gb`.
# Increasing `max_cache_ram_gb` will increase the amount of RAM used to cache inactive models, resulting in faster model
# reloads for the cached models.
# As an example, if your system has 32GB of RAM and no other heavy processes, setting the `max_cache_ram_gb` to 28GB
# might be a good value to achieve aggressive model caching.
# Set the RAM cache size to as large as possible, leaving a few GB free for the rest of your system and Invoke.
# For example, if your system has 32GB RAM, 28GB is a good value.
max_cache_ram_gb: 28
# The default max cache VRAM size is adjusted dynamically based on the amount of available VRAM (taking into
# consideration the VRAM used by other processes).
# You can override the default value by setting `max_cache_vram_gb`. Note that this value takes precedence over the
# `device_working_mem_gb`.
# It is recommended to set the VRAM cache size to be as large as possible while leaving enough room for the working
# memory of the tasks you will be doing. For example, on a 24GB GPU that will be running unquantized FLUX without any
# auxiliary models, 18GB might be a good value.
# Set the VRAM cache size to be as large as possible while leaving enough room for the working memory of the tasks you will be doing.
# For example, on a 24GB GPU that will be running unquantized FLUX without any auxiliary models,
# 18GB is a good value.
max_cache_vram_gb: 18
```
@@ -118,15 +109,6 @@ device_working_mem_gb: 4
Once decoding completes, the model manager "reclaims" the extra VRAM allocated as working memory for future model loading operations.
### Keeping a RAM weight copy
Invoke has the option of keeping a RAM copy of all model weights, even when they are loaded onto the GPU. This optimization is _on_ by default, and enables faster model switching and LoRA patching. Disabling this feature will reduce the average RAM load while running Invoke (peak RAM likely won't change), at the cost of slower model switching and LoRA patching. If you have limited RAM, you can disable this optimization:
```yaml
# Set to false to reduce the average RAM usage at the cost of slower model switching and LoRA patching.
keep_ram_copy_of_weights: false
```
### Disabling Nvidia sysmem fallback (Windows only)
On Windows, Nvidia GPUs are able to use system RAM when their VRAM fills up via **sysmem fallback**. While it sounds like a good idea on the surface, in practice it causes massive slowdowns during generation.
@@ -145,19 +127,3 @@ It is strongly suggested to disable this feature:
If the sysmem fallback feature sounds familiar, that's because Invoke's partial model loading strategy is conceptually very similar - use VRAM when there's room, else fall back to RAM.
Unfortunately, the Nvidia implementation is not optimized for applications like Invoke and does more harm than good.
## Troubleshooting
### Windows page file
Invoke has high virtual memory (a.k.a. 'committed memory') requirements. This can cause issues on Windows if the page file size limits are hit. (See this issue for the technical details on why this happens: https://github.com/invoke-ai/InvokeAI/issues/7563).
If you run out of page file space, InvokeAI may crash. Often, these crashes will happen with one of the following errors:
- InvokeAI exits with Windows error code `3221225477`
- InvokeAI crashes without an error, but `eventvwr.msc` reveals an error with code `0xc0000005` (the hex equivalent of `3221225477`)
If you are running out of page file space, try the following solutions:
- Make sure that you have sufficient disk space for the page file to grow. Watch your disk usage as Invoke runs. If it climbs near 100% leading up to the crash, then this is very likely the source of the issue. Clear out some disk space to resolve the issue.
- Make sure that your page file is set to "System managed size" (this is the default) rather than a custom size. Under the "System managed size" policy, the page file will grow dynamically as needed.
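As a rough way to act on the disk-usage advice above, the following standalone sketch polls disk and RAM usage while Invoke runs. It uses the third-party `psutil` package and a hypothetical `C:\` page-file drive; neither is part of this documentation or this change.

```python
import time

import psutil  # third-party package; an assumption, not something Invoke ships or documents

# Poll disk usage on the drive assumed to host the page file, plus overall RAM usage.
# If disk usage climbs toward 100% right before a crash, the page file is the likely culprit.
PAGEFILE_DRIVE = "C:\\"  # hypothetical; adjust to the drive that holds your page file

for _ in range(12):
    disk = psutil.disk_usage(PAGEFILE_DRIVE)
    ram = psutil.virtual_memory()
    print(f"disk {disk.percent:.1f}% full | RAM {ram.percent:.1f}% used")
    time.sleep(5)
```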

View File

@@ -88,13 +88,13 @@ The following commands vary depending on the version of Invoke being installed a
8. Install the `invokeai` package. Substitute the package specifier and version.
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
uv pip install <PACKAGE_SPECIFIER>=<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
```
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
uv pip install <PACKAGE_SPECIFIER>=<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
```
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:

View File

@@ -858,18 +858,6 @@ async def get_stats() -> Optional[CacheStats]:
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
@model_manager_router.post(
"/empty_model_cache",
operation_id="empty_model_cache",
status_code=200,
)
async def empty_model_cache() -> None:
"""Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped."""
# Request 1000GB of room in order to force the cache to drop all models.
ApiDependencies.invoker.services.logger.info("Emptying model cache.")
ApiDependencies.invoker.services.model_manager.load.ram_cache.make_room(1000 * 2**30)
class HFTokenStatus(str, Enum):
VALID = "valid"
INVALID = "invalid"

View File

@@ -10,7 +10,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
CancelByBatchIDsResult,
CancelByDestinationResult,
ClearResult,
@@ -95,18 +94,6 @@ async def Pause(
return ApiDependencies.invoker.services.session_processor.pause()
@session_queue_router.put(
"/{queue_id}/cancel_all_except_current",
operation_id="cancel_all_except_current",
responses={200: {"model": CancelAllExceptCurrentResult}},
)
async def cancel_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> CancelAllExceptCurrentResult:
"""Immediately cancels all queue items except in-processing items"""
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
@session_queue_router.put(
"/{queue_id}/cancel_by_batch_ids",
operation_id="cancel_by_batch_ids",

View File

@@ -25,7 +25,6 @@ async def parse_dynamicprompts(
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
max_prompts: int = Body(ge=1, le=10000, default=1000, description="The max number of prompts to generate"),
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
seed: int | None = Body(None, description="The seed to use for random generation. Only used if not combinatorial"),
) -> DynamicPromptsResponse:
"""Creates a batch process"""
max_prompts = min(max_prompts, 10000)
@@ -36,7 +35,7 @@ async def parse_dynamicprompts(
generator = CombinatorialPromptGenerator()
prompts = generator.generate(prompt, max_prompts=max_prompts)
else:
generator = RandomPromptGenerator(seed=seed)
generator = RandomPromptGenerator()
prompts = generator.generate(prompt, num_images=max_prompts)
except ParseException as e:
prompts = [prompt]
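For context on the `seed` parameter touched in this hunk, here is a minimal standalone sketch that mirrors the generator calls shown above. It assumes the upstream `dynamicprompts` package and an illustrative prompt; it is not code from this change.

```python
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator

prompt = "a {red|blue|green} ball"  # illustrative dynamic prompt

# Combinatorial generation enumerates every combination, capped at max_prompts.
combinatorial = CombinatorialPromptGenerator()
print(combinatorial.generate(prompt, max_prompts=10))

# Random generation samples combinations; one side of this diff threads an optional
# seed through the request body so the sampling is reproducible.
random_generator = RandomPromptGenerator(seed=123)
print(random_generator.generate(prompt, num_images=4))
```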

View File

@@ -1,237 +0,0 @@
from typing import Literal
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
ImageField,
Input,
InputField,
OutputField,
)
from invokeai.app.invocations.primitives import (
FloatOutput,
ImageOutput,
IntegerOutput,
StringOutput,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
BATCH_GROUP_IDS = Literal[
"None",
"Group 1",
"Group 2",
"Group 3",
"Group 4",
"Group 5",
]
class NotExecutableNodeError(Exception):
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
super().__init__(message)
pass
class BaseBatchInvocation(BaseInvocation):
batch_group_id: BATCH_GROUP_IDS = InputField(
default="None",
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
input=Input.Direct,
title="Batch Group",
)
def __init__(self):
raise NotExecutableNodeError()
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(
default=[], min_length=1, description="The images to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotExecutableNodeError()
@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
strings: list[str] = InputField(
default=[],
min_length=1,
description="The strings to batch over",
)
def invoke(self, context: InvocationContext) -> StringOutput:
raise NotExecutableNodeError()
@invocation_output("string_generator_output")
class StringGeneratorOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of strings"""
strings: list[str] = OutputField(description="The generated strings")
class StringGeneratorField(BaseModel):
pass
@invocation(
"string_generator",
title="String Generator",
tags=["primitives", "string", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringGenerator(BaseInvocation):
"""Generated a range of strings for use in a batched generation"""
generator: StringGeneratorField = InputField(
description="The string generator.",
input=Input.Direct,
title="Generator Type",
)
def __init__(self):
raise NotExecutableNodeError()
def invoke(self, context: InvocationContext) -> StringGeneratorOutput:
raise NotExecutableNodeError()
@invocation(
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
integers: list[int] = InputField(
default=[],
min_length=1,
description="The integers to batch over",
)
def invoke(self, context: InvocationContext) -> IntegerOutput:
raise NotExecutableNodeError()
@invocation_output("integer_generator_output")
class IntegerGeneratorOutput(BaseInvocationOutput):
integers: list[int] = OutputField(description="The generated integers")
class IntegerGeneratorField(BaseModel):
pass
@invocation(
"integer_generator",
title="Integer Generator",
tags=["primitives", "int", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerGenerator(BaseInvocation):
"""Generated a range of integers for use in a batched generation"""
generator: IntegerGeneratorField = InputField(
description="The integer generator.",
input=Input.Direct,
title="Generator Type",
)
def __init__(self):
raise NotExecutableNodeError()
def invoke(self, context: InvocationContext) -> IntegerGeneratorOutput:
raise NotExecutableNodeError()
@invocation(
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
floats: list[float] = InputField(
default=[],
min_length=1,
description="The floats to batch over",
)
def invoke(self, context: InvocationContext) -> FloatOutput:
raise NotExecutableNodeError()
@invocation_output("float_generator_output")
class FloatGeneratorOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of floats"""
floats: list[float] = OutputField(description="The generated floats")
class FloatGeneratorField(BaseModel):
pass
@invocation(
"float_generator",
title="Float Generator",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatGenerator(BaseInvocation):
"""Generated a range of floats for use in a batched generation"""
generator: FloatGeneratorField = InputField(
description="The float generator.",
input=Input.Direct,
title="Generator Type",
)
def __init__(self):
raise NotExecutableNodeError()
def invoke(self, context: InvocationContext) -> FloatGeneratorOutput:
raise NotExecutableNodeError()
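The `batch_group_id` field above describes collections being "zipped" before execution. The following is a conceptual sketch of that behavior in plain Python, not InvokeAI code; the file names and prompts are made up.

```python
# Two batch nodes in the same group ("Group 1"): their collections are zipped,
# so they must be the same length, and the workflow runs once per zipped tuple.
images = ["img_a.png", "img_b.png"]   # e.g. an ImageBatchInvocation collection
prompts = ["a cat", "a dog"]          # e.g. a StringBatchInvocation collection

if len(images) != len(prompts):
    raise ValueError("All batch nodes in a group must have collections of the same size")

for image_name, prompt in zip(images, prompts):
    # each iteration corresponds to one execution of the workflow
    print(f"execute workflow with image={image_name!r}, prompt={prompt!r}")
```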

View File

@@ -40,7 +40,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -86,7 +85,6 @@ def get_scheduler(
scheduler_info: ModelIdentifierField,
scheduler_name: str,
seed: int,
unet_config: AnyModelConfig,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
@@ -105,9 +103,6 @@ def get_scheduler(
"_backup": scheduler_config,
}
if hasattr(unet_config, "prediction_type"):
scheduler_config["prediction_type"] = unet_config.prediction_type
# make dpmpp_sde reproducable(seed can be passed only in initializer)
if scheduler_class is DPMSolverSDEScheduler:
scheduler_config["noise_sampler_seed"] = seed
@@ -834,9 +829,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
@@ -856,7 +848,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
@@ -868,6 +859,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
@@ -1036,7 +1030,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
pipeline = self.create_pipeline(unet, scheduler)

View File

@@ -300,13 +300,6 @@ class BoundingBoxField(BaseModel):
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
def tuple(self) -> Tuple[int, int, int, int]:
"""
Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.
This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
"""
return (self.x_min, self.y_min, self.x_max, self.y_max)
class MetadataField(RootModel[dict[str, Any]]):
"""

View File

@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@@ -21,9 +21,6 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
t5_encoder: Optional[T5EncoderField] = OutputField(
default=None, description=FieldDescriptions.t5_encoder, title="T5 Encoder"
)
@invocation(
@@ -31,7 +28,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.2.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxLoRALoaderInvocation(BaseInvocation):
@@ -53,12 +50,6 @@ class FluxLoRALoaderInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
@@ -71,8 +62,6 @@ class FluxLoRALoaderInvocation(BaseInvocation):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.')
output = FluxLoRALoaderOutput()
@@ -93,14 +82,6 @@ class FluxLoRALoaderInvocation(BaseInvocation):
weight=self.weight,
)
)
if self.t5_encoder is not None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return output
@@ -110,14 +91,14 @@ class FluxLoRALoaderInvocation(BaseInvocation):
title="FLUX LoRA Collection Loader",
tags=["lora", "model", "flux"],
category="model",
version="1.3.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to a FLUX transformer."""
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
transformer: Optional[TransformerField] = InputField(
@@ -132,30 +113,13 @@ class FLUXLoRACollectionLoader(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
output = FluxLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
if self.t5_encoder is not None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
@@ -166,13 +130,14 @@ class FLUXLoRACollectionLoader(BaseInvocation):
added_loras.append(lora.lora.key)
if self.transformer is not None and output.transformer is not None:
if self.transformer is not None:
if output.transformer is None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(lora)
if self.clip is not None and output.clip is not None:
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
if self.t5_encoder is not None and output.t5_encoder is not None:
output.t5_encoder.loras.append(lora)
return output

View File

@@ -10,10 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
@@ -40,7 +36,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.5",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
@@ -78,8 +74,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
@@ -87,7 +83,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)

View File

@@ -2,7 +2,7 @@ from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
@@ -19,7 +19,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -71,44 +71,12 @@ class FluxTextEncoderInvocation(BaseInvocation):
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
prompt = [self.prompt]
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_encoder_config = t5_encoder_info.config
assert t5_encoder_config is not None
with (
t5_encoder_info.model_on_device() as (cached_weights, t5_text_encoder),
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
# Determine if the model is quantized.
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
# slower inference than direct patching, but is agnostic to the quantization format.
if t5_encoder_config.format in [ModelFormat.T5Encoder, ModelFormat.Diffusers]:
model_is_quantized = False
elif t5_encoder_config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
model_is_quantized = True
else:
raise ValueError(f"Unsupported model format: {t5_encoder_config.format}")
# Apply LoRA models to the T5 encoder.
# Note: We apply the LoRA after the encoder has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=t5_text_encoder,
patches=self._t5_lora_iterator(context),
prefix=FLUX_LORA_T5_PREFIX,
dtype=t5_text_encoder.dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)
assert isinstance(t5_tokenizer, T5Tokenizer)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
@@ -164,10 +132,3 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _t5_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.t5_encoder.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -21,7 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput):
"ideal_size",
title="Ideal Size",
tags=["latents", "math", "ideal_size"],
version="1.0.4",
version="1.0.3",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
@@ -41,16 +41,11 @@ class IdealSizeInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(self.unet.unet.key)
aspect = self.width / self.height
if unet_config.base == BaseModelType.StableDiffusion1:
dimension = 512
elif unet_config.base == BaseModelType.StableDiffusion2:
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif unet_config.base in (BaseModelType.StableDiffusionXL, BaseModelType.Flux, BaseModelType.StableDiffusion3):
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
else:
raise ValueError(f"Unsupported model type: {unet_config.base}")
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images

View File

@@ -13,7 +13,6 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
BoundingBoxField,
ColorField,
FieldDescriptions,
ImageField,
@@ -24,7 +23,6 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
@@ -163,12 +161,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
crop: bool = InputField(default=False, description="Crop to base image dimensions")
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.images.get_pil(self.base_image.image_name, mode="RGBA")
image = context.images.get_pil(self.image.image_name, mode="RGBA")
base_image = context.images.get_pil(self.base_image.image_name)
image = context.images.get_pil(self.image.image_name)
mask = None
if self.mask is not None:
mask = context.images.get_pil(self.mask.image_name, mode="L")
mask = ImageOps.invert(mask)
mask = context.images.get_pil(self.mask.image_name)
mask = ImageOps.invert(mask.convert("L"))
# TODO: probably shouldn't invert mask here... should user be required to do it?
min_x = min(0, self.x)
@@ -178,11 +176,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
new_image.paste(base_image, (abs(min_x), abs(min_y)))
# Create a temporary image to paste the image with transparency
temp_image = Image.new("RGBA", new_image.size)
temp_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
new_image = Image.alpha_composite(new_image, temp_image)
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
if self.crop:
base_w, base_h = base_image.size
@@ -307,44 +301,14 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
image = context.images.get_pil(self.image.image_name)
# Split the image into RGBA channels
r, g, b, a = image.split()
# Premultiply RGB channels by alpha
premultiplied_image = ImageChops.multiply(image, a.convert("RGBA"))
premultiplied_image.putalpha(a)
# Apply the blur
blur = (
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
)
blurred_image = premultiplied_image.filter(blur)
blur_image = image.filter(blur)
# Split the blurred image into RGBA channels
r, g, b, a_orig = blurred_image.split()
# Convert to float using NumPy. float 32/64 division are much faster than float 16
r = numpy.array(r, dtype=numpy.float32)
g = numpy.array(g, dtype=numpy.float32)
b = numpy.array(b, dtype=numpy.float32)
a = numpy.array(a_orig, dtype=numpy.float32) / 255.0 # Normalize alpha to [0, 1]
# Unpremultiply RGB channels by alpha
r /= a + 1e-6 # Add a small epsilon to avoid division by zero
g /= a + 1e-6
b /= a + 1e-6
# Convert back to PIL images
r = Image.fromarray(numpy.uint8(numpy.clip(r, 0, 255)))
g = Image.fromarray(numpy.uint8(numpy.clip(g, 0, 255)))
b = Image.fromarray(numpy.uint8(numpy.clip(b, 0, 255)))
# Merge back into a single image
result_image = Image.merge("RGBA", (r, g, b, a_orig))
image_dto = context.images.save(image=result_image)
image_dto = context.images.save(image=blur_image)
return ImageOutput.build(image_dto)
@@ -843,7 +807,7 @@ CHANNEL_FORMATS = {
"value",
],
category="image",
version="1.2.3",
version="1.2.2",
)
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add or subtract a value from a specific color channel of an image."""
@@ -853,22 +817,18 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGBA")
pil_image = context.images.get_pil(self.image.image_name)
# extract the channel and mode from the input and reference tuple
mode = CHANNEL_FORMATS[self.channel][0]
channel_number = CHANNEL_FORMATS[self.channel][1]
# Convert PIL image to new format
converted_image = numpy.array(image.convert(mode)).astype(int)
converted_image = numpy.array(pil_image.convert(mode)).astype(int)
image_channel = converted_image[:, :, channel_number]
if self.channel == "Hue (HSV)":
# loop around the values because hue is special
image_channel = (image_channel + self.offset) % 256
else:
# Adjust the value, clipping to 0..255
image_channel = numpy.clip(image_channel + self.offset, 0, 255)
# Adjust the value, clipping to 0..255
image_channel = numpy.clip(image_channel + self.offset, 0, 255)
# Put the channel back into the image
converted_image[:, :, channel_number] = image_channel
@@ -876,10 +836,6 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
# Convert back to RGBA format and output
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
# restore the alpha channel
if self.channel != "Alpha (RGBA)":
pil_image.putalpha(image.getchannel("A"))
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
@@ -907,7 +863,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
"value",
],
category="image",
version="1.2.3",
version="1.2.2",
)
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Scale a specific color channel of an image."""
@@ -918,14 +874,14 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
pil_image = context.images.get_pil(self.image.image_name)
# extract the channel and mode from the input and reference tuple
mode = CHANNEL_FORMATS[self.channel][0]
channel_number = CHANNEL_FORMATS[self.channel][1]
# Convert PIL image to new format
converted_image = numpy.array(image.convert(mode)).astype(float)
converted_image = numpy.array(pil_image.convert(mode)).astype(float)
image_channel = converted_image[:, :, channel_number]
# Adjust the value, clipping to 0..255
@@ -941,10 +897,6 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
# Convert back to RGBA format and output
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
# restore the alpha channel
if self.channel != "Alpha (RGBA)":
pil_image.putalpha(image.getchannel("A"))
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
@@ -1010,10 +962,10 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
@invocation(
"mask_from_id",
title="Mask from Segmented Image",
title="Mask from ID",
tags=["image", "mask", "id"],
category="image",
version="1.0.1",
version="1.0.0",
)
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generate a mask for a particular color in an ID Map"""
@@ -1023,24 +975,40 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
threshold: int = InputField(default=100, description="Threshold for color detection")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
def rgba_to_hex(self, rgba_color: tuple[int, int, int, int]):
r, g, b, a = rgba_color
hex_code = "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, int(a * 255))
return hex_code
np_color = numpy.array(self.color.tuple())
def id_to_mask(self, id_mask: Image.Image, color: tuple[int, int, int, int], threshold: int = 100):
if id_mask.mode != "RGB":
id_mask = id_mask.convert("RGB")
# Can directly just use the tuple but I'll leave this rgba_to_hex here
# incase anyone prefers using hex codes directly instead of the color picker
hex_color_str = self.rgba_to_hex(color)
rgb_color = numpy.array([int(hex_color_str[i : i + 2], 16) for i in (1, 3, 5)])
# Maybe there's a faster way to calculate this distance but I can't think of any right now.
color_distance = numpy.linalg.norm(image - np_color, axis=-1)
color_distance = numpy.linalg.norm(id_mask - rgb_color, axis=-1)
# Create a mask based on the threshold and the distance calculated above
binary_mask = (color_distance < self.threshold).astype(numpy.uint8) * 255
binary_mask = (color_distance < threshold).astype(numpy.uint8) * 255
# Convert the mask back to PIL
binary_mask_pil = Image.fromarray(binary_mask)
if self.invert:
binary_mask_pil = ImageOps.invert(binary_mask_pil)
return binary_mask_pil
image_dto = context.images.save(image=binary_mask_pil, image_category=ImageCategory.MASK)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
mask = self.id_to_mask(image, self.color.tuple(), self.threshold)
if self.invert:
mask = ImageOps.invert(mask)
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
@@ -1087,123 +1055,3 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=generated_image)
return ImageOutput.build(image_dto)
@invocation(
"img_noise",
title="Add Image Noise",
tags=["image", "noise"],
category="image",
version="1.0.1",
)
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add noise to an image"""
image: ImageField = InputField(description="The image to add noise to")
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description=FieldDescriptions.seed,
)
noise_type: Literal["gaussian", "salt_and_pepper"] = InputField(
default="gaussian",
description="The type of noise to add",
)
amount: float = InputField(default=0.1, ge=0, le=1, description="The amount of noise to add")
noise_color: bool = InputField(default=True, description="Whether to add colored noise")
size: int = InputField(default=1, ge=1, description="The size of the noise points")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, mode="RGBA")
# Save out the alpha channel
alpha = image.getchannel("A")
# Set the seed for numpy random
rs = numpy.random.RandomState(numpy.random.MT19937(numpy.random.SeedSequence(self.seed)))
if self.noise_type == "gaussian":
if self.noise_color:
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size, 3)) * 255
else:
noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size)) * 255
noise = numpy.stack([noise] * 3, axis=-1)
elif self.noise_type == "salt_and_pepper":
if self.noise_color:
noise = rs.choice(
[0, 255], (image.height // self.size, image.width // self.size, 3), p=[1 - self.amount, self.amount]
)
else:
noise = rs.choice(
[0, 255], (image.height // self.size, image.width // self.size), p=[1 - self.amount, self.amount]
)
noise = numpy.stack([noise] * 3, axis=-1)
noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
(image.width, image.height), Image.Resampling.NEAREST
)
noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")
# Paste back the alpha channel
noisy_image.putalpha(alpha)
image_dto = context.images.save(image=noisy_image)
return ImageOutput.build(image_dto)
@invocation(
"crop_image_to_bounding_box",
title="Crop Image to Bounding Box",
category="image",
version="1.0.0",
tags=["image", "crop"],
classification=Classification.Beta,
)
class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Crop an image to the given bounding box. If the bounding box is omitted, the image is cropped to the non-transparent pixels."""
image: ImageField = InputField(description="The image to crop")
bounding_box: BoundingBoxField | None = InputField(
default=None, description="The bounding box to crop the image to"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
bounding_box = self.bounding_box.tuple() if self.bounding_box is not None else image.getbbox()
cropped_image = image.crop(bounding_box)
image_dto = context.images.save(image=cropped_image)
return ImageOutput.build(image_dto)
@invocation(
"paste_image_into_bounding_box",
title="Paste Image into Bounding Box",
category="image",
version="1.0.0",
tags=["image", "crop"],
classification=Classification.Beta,
)
class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Paste the source image into the target image at the given bounding box.
The source image must be the same size as the bounding box, and the bounding box must fit within the target image."""
source_image: ImageField = InputField(description="The image to paste")
target_image: ImageField = InputField(description="The image to paste into")
bounding_box: BoundingBoxField = InputField(description="The bounding box to paste the image into")
def invoke(self, context: InvocationContext) -> ImageOutput:
source_image = context.images.get_pil(self.source_image.image_name, mode="RGBA")
target_image = context.images.get_pil(self.target_image.image_name, mode="RGBA")
bounding_box = self.bounding_box.tuple()
target_image.paste(source_image, bounding_box, source_image)
image_dto = context.images.save(image=target_image)
return ImageOutput.build(image_dto)

View File

@@ -2,22 +2,9 @@ import numpy as np
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
InvocationContext,
invocation,
)
from invokeai.app.invocations.fields import (
BoundingBoxField,
ColorField,
ImageField,
InputField,
TensorField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.primitives import BoundingBoxOutput, ImageOutput, MaskOutput
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
from invokeai.backend.image_util.util import pil_to_np
@@ -86,7 +73,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
title="Invert Tensor Mask",
tags=["conditioning"],
category="conditioning",
version="1.1.0",
version="1.0.0",
classification=Classification.Beta,
)
class InvertTensorMaskInvocation(BaseInvocation):
@@ -96,15 +83,6 @@ class InvertTensorMaskInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.tensors.load(self.mask.tensor_name)
# Verify dtype and shape.
assert mask.dtype == torch.bool
assert mask.dim() in [2, 3]
# Unsqueeze the channel dimension if it is missing. The MaskOutput type expects a single channel.
if mask.dim() == 2:
mask = mask.unsqueeze(0)
inverted = ~mask
return MaskOutput(
@@ -223,48 +201,3 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=masked_image)
return ImageOutput.build(image_dto)
WHITE = ColorField(r=255, g=255, b=255, a=255)
@invocation(
"get_image_mask_bounding_box",
title="Get Image Mask Bounding Box",
tags=["mask"],
category="mask",
version="1.0.0",
classification=Classification.Beta,
)
class GetMaskBoundingBoxInvocation(BaseInvocation):
"""Gets the bounding box of the given mask image."""
mask: ImageField = InputField(description="The mask to crop.")
margin: int = InputField(default=0, description="Margin to add to the bounding box.")
mask_color: ColorField = InputField(default=WHITE, description="Color of the mask in the image.")
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
mask = context.images.get_pil(self.mask.image_name, mode="RGBA")
mask_np = np.array(mask)
# Convert mask_color to RGBA tuple
mask_color_rgb = self.mask_color.tuple()
# Find the bounding box of the mask color
y, x = np.where(np.all(mask_np == mask_color_rgb, axis=-1))
if len(x) == 0 or len(y) == 0:
# No pixels found with the given color
return BoundingBoxOutput(bounding_box=BoundingBoxField(x_min=0, y_min=0, x_max=0, y_max=0))
left, upper, right, lower = x.min(), y.min(), x.max(), y.max()
# Add the margin
left = max(0, left - self.margin)
upper = max(0, upper - self.margin)
right = min(mask_np.shape[1], right + self.margin)
lower = min(mask_np.shape[0], lower + self.margin)
bounding_box = BoundingBoxField(x_min=left, y_min=upper, x_max=right, y_max=lower)
return BoundingBoxOutput(bounding_box=bounding_box)

View File

@@ -68,7 +68,6 @@ class CLIPField(BaseModel):
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class VAEField(BaseModel):
@@ -206,7 +205,7 @@ class LoRALoaderInvocation(BaseInvocation):
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
@@ -257,12 +256,12 @@ class LoRASelectorInvocation(BaseInvocation):
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.1.0")
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.0.0")
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
@@ -282,14 +281,7 @@ class LoRACollectionLoader(BaseInvocation):
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
@@ -300,10 +292,14 @@ class LoRACollectionLoader(BaseInvocation):
added_loras.append(lora.lora.key)
if self.unet is not None and output.unet is not None:
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None and output.clip is not None:
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
return output
@@ -403,13 +399,13 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
title="SDXL LoRA Collection Loader",
tags=["model"],
category="model",
version="1.1.0",
version="1.0.0",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
@@ -435,18 +431,7 @@ class SDXLLoRACollectionLoader(BaseInvocation):
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
if self.clip2 is not None:
output.clip2 = self.clip2.model_copy(deep=True)
for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
@@ -457,13 +442,19 @@ class SDXLLoRACollectionLoader(BaseInvocation):
added_loras.append(lora.lora.key)
if self.unet is not None and output.unet is not None:
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None and output.clip is not None:
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
if self.clip2 is not None and output.clip2 is not None:
if self.clip2 is not None:
if output.clip2 is None:
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(lora)
return output
@@ -481,7 +472,7 @@ class VAELoaderInvocation(BaseInvocation):
key = self.vae_model.key
if not context.models.exists(key):
raise Exception(f"Unknown vae: {key}!")
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VAEField(vae=self.vae_model))

View File

@@ -7,6 +7,7 @@ import torch
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -416,7 +417,6 @@ class ColorInvocation(BaseInvocation):
class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor."""
# shape: [1, H, W], dtype: bool
mask: TensorField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.")
@@ -539,3 +539,23 @@ class BoundingBoxInvocation(BaseInvocation):
# endregion
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "internal"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")

View File

@@ -10,10 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.model_manager.config import SubModelType
@@ -92,13 +88,21 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
if self.clip_g_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
)
tokenizer_t5 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model or self.model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model or self.model)
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
return Sd3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder, loras=[]),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
)

View File

@@ -49,7 +49,7 @@ class SAMPointsField(BaseModel):
title="Segment Anything",
tags=["prompt", "segmentation"],
category="segmentation",
version="1.2.0",
version="1.1.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Runs a Segment Anything Model."""
@@ -96,10 +96,8 @@ class SegmentAnythingInvocation(BaseInvocation):
# masks contains bool values, so we merge them via max-reduce.
combined_mask, _ = torch.stack(masks).max(dim=0)
# Unsqueeze the channel dimension.
combined_mask = combined_mask.unsqueeze(0)
mask_tensor_name = context.tensors.save(combined_mask)
_, height, width = combined_mask.shape
height, width = combined_mask.shape
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
@staticmethod
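As a shape check for the mask-merging lines above, here is a minimal standalone sketch with dummy boolean masks; the 64x64 size is arbitrary.

```python
import torch

masks = [
    torch.zeros(64, 64, dtype=torch.bool),
    torch.ones(64, 64, dtype=torch.bool),
]

# Boolean masks are merged via max-reduce, then a channel dimension is added
# so the result has shape [1, H, W].
combined_mask, _ = torch.stack(masks).max(dim=0)
combined_mask = combined_mask.unsqueeze(0)
print(combined_mask.shape)  # torch.Size([1, 64, 64])
```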

View File

@@ -218,7 +218,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)

View File

@@ -87,7 +87,6 @@ class InvokeAIAppConfig(BaseSettings):
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
@@ -163,7 +162,6 @@ class InvokeAIAppConfig(BaseSettings):
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as it's used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
# Deprecated CACHE configs
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")

View File

@@ -84,7 +84,6 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
execution_device_working_mem_gb=app_config.device_working_mem_gb,
enable_partial_loading=app_config.enable_partial_loading,
keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights,
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
execution_device=execution_device or TorchDevice.choose_torch_device(),

View File

@@ -5,7 +5,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
CancelByBatchIDsResult,
CancelByDestinationResult,
CancelByQueueIDResult,
@@ -113,11 +112,6 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with matching queue ID"""
pass
@abstractmethod
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
"""Cancels all queue items except in-progress items"""
pass
@abstractmethod
def list_queue_items(
self,

View File

@@ -108,16 +108,8 @@ class Batch(BaseModel):
return v
for batch_data_list in v:
for datum in batch_data_list:
if not datum.items:
continue
# Special handling for numbers - they can be mixed
# TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
if all(isinstance(item, (int, float)) for item in datum.items):
continue
# Get the type of the first item in the list
first_item_type = type(datum.items[0])
first_item_type = type(datum.items[0]) if datum.items else None
for item in datum.items:
if type(item) is not first_item_type:
raise BatchItemsTypeError("All items in a batch must have the same type")
@@ -374,12 +366,6 @@ class CancelByQueueIDResult(CancelByBatchIDsResult):
pass
class CancelAllExceptCurrentResult(CancelByBatchIDsResult):
"""Result of canceling all except current"""
pass
class IsEmptyResult(BaseModel):
"""Result of checking if the session queue is empty"""

View File

@@ -9,7 +9,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
CancelByBatchIDsResult,
CancelByDestinationResult,
CancelByQueueIDResult,
@@ -511,39 +510,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
try:
where = """--sql
WHERE
queue_id == ?
AND status == 'pending'
"""
self.__lock.acquire()
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
(queue_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
try:
self.__lock.acquire()

View File

@@ -51,18 +51,15 @@ class Edge(BaseModel):
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
def __str__(self):
return f"{self.source.node_id}.{self.source.field} -> {self.destination.node_id}.{self.destination.field}"
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
def get_input_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
@@ -96,10 +93,6 @@ def is_list_or_contains_list(t):
return False
def is_any(t: Any) -> bool:
return t == Any or Any in get_args(t)
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type:
return False
@@ -109,7 +102,13 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
# TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
if from_type and to_type:
# Ports are compatible
if from_type == to_type or is_any(from_type) or is_any(to_type):
if (
from_type == to_type
or from_type == Any
or to_type == Any
or Any in get_args(from_type)
or Any in get_args(to_type)
):
return True
if from_type in get_args(to_type):
@@ -141,10 +140,10 @@ def are_connections_compatible(
"""Determines if a connection between fields of two nodes is compatible."""
# TODO: handle iterators and collectors
from_type = get_output_field_type(from_node, from_field)
to_type = get_input_field_type(to_node, to_field)
from_node_field = get_output_field(from_node, from_field)
to_node_field = get_input_field(to_node, to_field)
return are_connection_types_compatible(from_type, to_type)
return are_connection_types_compatible(from_node_field, to_node_field)
T = TypeVar("T")
@@ -441,19 +440,17 @@ class Graph(BaseModel):
self.get_node(edge.destination.node_id),
edge.destination.field,
):
raise InvalidEdgeError(f"Edge source and target types do not match ({edge})")
raise InvalidEdgeError(
f"Invalid edge from {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate all iterators & collectors
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
for node in self.nodes.values():
if isinstance(node, IterateInvocation):
err = self._is_iterator_connection_valid(node.id)
if err is not None:
raise InvalidEdgeError(f"Invalid iterator node ({node.id}): {err}")
if isinstance(node, CollectInvocation):
err = self._is_collector_connection_valid(node.id)
if err is not None:
raise InvalidEdgeError(f"Invalid collector node ({node.id}): {err}")
if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id):
raise InvalidEdgeError(f"Invalid iterator node {node.id}")
if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id):
raise InvalidEdgeError(f"Invalid collector node {node.id}")
return None
@@ -480,11 +477,11 @@ class Graph(BaseModel):
def _is_destination_field_Any(self, edge: Edge) -> bool:
"""Checks if the destination field for an edge is of type typing.Any"""
return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == Any
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == Any
def _is_destination_field_list_of_Any(self, edge: Edge) -> bool:
"""Checks if the destination field for an edge is of type typing.Any"""
return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
@@ -494,40 +491,55 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
raise InvalidEdgeError(f"One or both nodes don't exist ({edge})")
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
raise InvalidEdgeError(f"Edge already exists ({edge})")
raise InvalidEdgeError(
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
)
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(f"Edge creates a cycle in the graph ({edge})")
raise InvalidEdgeError(
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
)
# Validate that the field types are compatible
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
raise InvalidEdgeError(f"Field types are incompatible ({edge})")
raise InvalidEdgeError(
f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}")
if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination)
if err is not None:
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")
if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate that we are not connecting collector to iterator (currently unsupported)
if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
raise InvalidEdgeError(
f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
if (
@@ -536,9 +548,10 @@ class Graph(BaseModel):
and not self._is_destination_field_list_of_Any(edge)
and not self._is_destination_field_Any(edge)
):
err = self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination)
if err is not None:
raise InvalidEdgeError(f"Collector input type does not match collector output type ({edge}): {err}")
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
def has_node(self, node_id: str) -> bool:
"""Determines whether or not a node exists in the graph."""
@@ -621,7 +634,7 @@ class Graph(BaseModel):
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> str | None:
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
@@ -632,47 +645,29 @@ class Graph(BaseModel):
# Only one input is allowed for iterators
if len(inputs) > 1:
return "Iterator may only have one input edge"
input_node = self.get_node(inputs[0].node_id)
return False
# Get input and output fields (the fields linked to the iterator's input/output)
input_field_type = get_output_field_type(input_node, inputs[0].field)
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
if get_origin(input_field_type) is not list:
return "Iterator input must be a collection"
if get_origin(input_field) is not list:
return False
# Validate that all outputs match the input type
input_field_item_type = get_args(input_field_type)[0]
if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)):
return "Iterator outputs must connect to an input with a matching type"
input_field_item_type = get_args(input_field)[0]
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
return False
# Collector input type must match all iterator output types
if isinstance(input_node, CollectInvocation):
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0]
first_collector_input_type = get_output_field_type(
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
)
resolved_collector_type = (
first_collector_input_type
if get_origin(first_collector_input_type) is None
else get_args(first_collector_input_type)
)
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
return "Iterator collection type must match all iterator output types"
return None
return True
def _is_collector_connection_valid(
self,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> str | None:
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
@@ -682,42 +677,38 @@ class Graph(BaseModel):
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs]
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Validate that all inputs are derived from or match a single type
input_field_types = {
resolved_type
for input_field_type in input_field_types
for resolved_type in (
[input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
)
if resolved_type != NoneType
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
} # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
return "Collector input collection items must be of a single type"
return False # There is more than one root type
# Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists
if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types):
return "Collector output must connect to a collection input"
if not all(is_list_or_contains_list(f) for f in output_fields):
return False
# Verify that all outputs match the input type (are a base class or the same class)
if not all(
is_any(t)
or is_union_subtype(input_root_type, get_args(t)[0])
or issubclass(input_root_type, get_args(t)[0])
for t in output_field_types
is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0])
for f in output_fields
):
return "Collector outputs must connect to a collection input with a matching type"
return False
return None
return True
def nx_graph(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the layout of this graph"""

View File

@@ -1,26 +0,0 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.backend.model_manager.config import BaseModelType, SubModelType
def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 encoder model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")
def preprocess_t5_tokenizer_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 tokenizer model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")

View File

@@ -1,19 +1,13 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PreTrainedModel, PreTrainedTokenizer
from invokeai.backend.util.devices import TorchDevice
class HFEncoder(nn.Module):
def __init__(
self,
encoder: PreTrainedModel,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
is_clip: bool,
max_length: int,
):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
super().__init__()
self.max_length = max_length
self.is_clip = is_clip

View File

@@ -9,17 +9,12 @@ class CachedModelOnlyFullLoad:
MPS memory, etc.
"""
def __init__(
self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False
):
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
"""Initialize a CachedModelOnlyFullLoad.
Args:
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
compute_device (torch.device): The compute device to move the model to.
total_bytes (int): The total size (in bytes) of all the weights in the model.
keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy
increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is
sufficient RAM).
"""
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
self._model = model
@@ -28,7 +23,7 @@ class CachedModelOnlyFullLoad:
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] | None = None
if isinstance(model, torch.nn.Module) and keep_ram_copy:
if isinstance(model, torch.nn.Module):
self._cpu_state_dict = model.state_dict()
self._total_bytes = total_bytes

View File

@@ -6,6 +6,18 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from invokeai.backend.util.logging import InvokeAILogger
# @contextmanager
# def apply_load_state_dict_pre_hook(model: torch.nn.Module, hook: Callable[..., None], with_module: bool = False):
# """Apply a pre-hook to the model's load_state_dict() method."""
# # NOTE(ryand): torch.nn.Module._register_load_state_dict_pre_hook() is a private method in the current version of
# # PyTorch, but has recently been made public:
# # https://github.com/pytorch/pytorch/commit/1dd10ac8029a08a88825515bdf81134a5cb61357
# handle = model._register_load_state_dict_pre_hook(hook, with_module) # type: ignore
# try:
# yield
# finally:
# handle.remove()
class CachedModelWithPartialLoad:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
@@ -54,36 +66,6 @@ class CachedModelWithPartialLoad:
return keys_in_modules_that_do_not_support_autocast
def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]:
"""A helper function that groups state dict keys by module prefix.
Example:
```
state_dict = {
"weight": ...,
"module.submodule.weight": ...,
"module.submodule.bias": ...,
"module.other_submodule.weight": ...,
"module.other_submodule.bias": ...,
}
output = group_state_dict_keys_by_module_prefix(state_dict)
# The output will be:
output = {
"": [
"weight",
],
"module.submodule": [
"module.submodule.weight",
"module.submodule.bias",
],
"module.other_submodule": [
"module.other_submodule.weight",
"module.other_submodule.bias",
],
}
```
"""
state_dict_keys_by_module_prefix: dict[str, list[str]] = {}
for key in state_dict.keys():
split = key.rsplit(".", 1)
@@ -144,38 +126,17 @@ class CachedModelWithPartialLoad:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())
def _load_state_dict_with_device_conversion(
def _load_state_dict(
self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
):
if self._cpu_state_dict is not None:
# Run the fast version.
self._load_state_dict_with_fast_device_conversion(
state_dict=state_dict,
keys_to_convert=keys_to_convert,
target_device=target_device,
cpu_state_dict=self._cpu_state_dict,
)
else:
# Run the low-virtual-memory version.
self._load_state_dict_with_jit_device_conversion(
state_dict=state_dict,
keys_to_convert=keys_to_convert,
target_device=target_device,
)
"""A custom state dict loading implementation.
def _load_state_dict_with_jit_device_conversion(
self,
state_dict: dict[str, torch.Tensor],
keys_to_convert: set[str],
target_device: torch.device,
):
"""A custom state dict loading implementation with good peak memory properties.
This implementation has the important property that it copies parameters to the target device one module at a time
rather than applying all of the device conversions and then calling load_state_dict(). This is done to minimize the
peak virtual memory usage. Specifically, we want to avoid a case where we hold references to all of the CPU weights
and CUDA weights simultaneously, because Windows will reserve virtual memory for both.
This implementation has two important properties:
- It copies parameters to the target device one module at a time rather than applying all of the device
conversions and then calling load_state_dict(). This is done to minimize the peak RAM usage.
- It leverages the `self._cpu_state_dict` if it exists to speed up transfers of weights to the CPU.
"""
target_device_is_cpu = target_device.type == "cpu"
for module_name, module in self._model.named_modules():
module_keys = self._state_dict_keys_by_module_prefix.get(module_name, [])
# Calculate the length of the module name prefix.
@@ -188,38 +149,23 @@ class CachedModelWithPartialLoad:
if key in keys_to_convert:
# It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same
# parameter.
state_dict[key] = state_dict[key].to(target_device)
if target_device_is_cpu and self._cpu_state_dict is not None:
state_dict[key] = self._cpu_state_dict[key]
else:
state_dict[key] = state_dict[key].to(target_device)
# Note that we keep parameters that have not been moved to a new device in case the module implements
# weird custom state dict loading logic that requires all parameters to be present.
module_state_dict[key[prefix_len:]] = state_dict[key]
if len(module_state_dict) > 0:
# We set strict=False, because if `module` has both parameters and child modules, then we are loading a
# state dict that only contains the parameters of `module` (not its children).
# state dict that only contains the parameters of `module` (not its children).
# We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf
# modules will recurse through all of the children, so is a bit wasteful.
incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True)
# Missing keys are ok, unexpected keys are not.
assert len(incompatible_keys.unexpected_keys) == 0
def _load_state_dict_with_fast_device_conversion(
self,
state_dict: dict[str, torch.Tensor],
keys_to_convert: set[str],
target_device: torch.device,
cpu_state_dict: dict[str, torch.Tensor],
):
"""Convert parameters to the target device and load them into the model. Leverages the `cpu_state_dict` to speed
up transfers of weights to the CPU.
"""
for key in keys_to_convert:
if target_device.type == "cpu":
state_dict[key] = cpu_state_dict[key]
else:
state_dict[key] = state_dict[key].to(target_device)
self._model.load_state_dict(state_dict, assign=True)
@torch.no_grad()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
@@ -278,7 +224,7 @@ class CachedModelWithPartialLoad:
# We load the entire state dict, not just the parameters that changed, in case there are modules that
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_load, self._compute_device)
self._load_state_dict(cur_state_dict, keys_to_load, self._compute_device)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
@@ -328,7 +274,7 @@ class CachedModelWithPartialLoad:
vram_bytes_freed += self._state_dict_bytes[key]
if len(keys_to_offload) > 0:
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_offload, torch.device("cpu"))
self._load_state_dict(cur_state_dict, keys_to_offload, torch.device("cpu"))
if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed
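
The change above replaces a single whole-model `load_state_dict()` call with a per-module walk. A rough, self-contained sketch of that idea, assuming a plain `torch.nn.Module`; the key grouping, autocast bookkeeping, and the RAM-copy fast path from the real class are omitted.

```python
import torch


def load_state_dict_one_module_at_a_time(
    model: torch.nn.Module, state_dict: dict[str, torch.Tensor], target_device: torch.device
) -> None:
    """Move weights to target_device and load them module-by-module so that a full
    second copy of the model's weights is never held at peak."""
    for module_name, module in model.named_modules():
        prefix = module_name + "." if module_name else ""
        module_sd: dict[str, torch.Tensor] = {}
        for key in list(state_dict.keys()):
            # Only this module's direct parameters (no further dots after the prefix).
            if key.startswith(prefix) and "." not in key[len(prefix):]:
                # Overwrite in place so the source tensor can be freed promptly.
                state_dict[key] = state_dict[key].to(target_device)
                module_sd[key[len(prefix):]] = state_dict[key]
        if module_sd:
            # strict=False because only this module's direct parameters are present.
            incompatible = module.load_state_dict(module_sd, strict=False, assign=True)
            assert len(incompatible.unexpected_keys) == 0


# Example usage with a toy model:
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
sd = {k: v.clone() for k, v in model.state_dict().items()}
load_state_dict_one_module_at_a_time(model, sd, torch.device("cpu"))
```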

View File

@@ -1,10 +1,8 @@
import gc
import logging
import threading
import time
from functools import wraps
from logging import Logger
from typing import Any, Callable, Dict, List, Optional
from typing import Dict, List, Optional
import psutil
import torch
@@ -43,17 +41,6 @@ def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] =
return model_key
def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
"""A decorator that applies the class's self._lock to the method."""
@wraps(method)
def wrapper(self, *args, **kwargs):
with self._lock: # Automatically acquire and release the lock
return method(self, *args, **kwargs)
return wrapper
class ModelCache:
"""A cache for managing models in memory.
@@ -91,7 +78,6 @@ class ModelCache:
self,
execution_device_working_mem_gb: float,
enable_partial_loading: bool,
keep_ram_copy_of_weights: bool,
max_ram_cache_size_gb: float | None = None,
max_vram_cache_size_gb: float | None = None,
execution_device: torch.device | str = "cuda",
@@ -119,7 +105,6 @@ class ModelCache:
:param logger: InvokeAILogger to use (otherwise creates one)
"""
self._enable_partial_loading = enable_partial_loading
self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
self._execution_device_working_mem_gb = execution_device_working_mem_gb
self._execution_device: torch.device = torch.device(execution_device)
self._storage_device: torch.device = torch.device(storage_device)
@@ -136,27 +121,16 @@ class ModelCache:
self._cached_models: Dict[str, CacheRecord] = {}
self._cache_stack: List[str] = []
self._ram_cache_size_bytes = self._calc_ram_available_to_model_cache()
# A lock applied to all public method calls to make the ModelCache thread-safe.
# At the time of writing, the ModelCache should only be accessed from two threads:
# - The graph execution thread
# - Requests to empty the cache from a separate thread
self._lock = threading.RLock()
@property
@synchronized
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
return self._stats
@stats.setter
@synchronized
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collecting cache statistics."""
self._stats = stats
@synchronized
def put(self, key: str, model: AnyModel) -> None:
"""Add a model to the cache."""
if key in self._cached_models:
@@ -180,13 +154,9 @@ class ModelCache:
# Wrap model.
if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
wrapped_model = CachedModelWithPartialLoad(
model, self._execution_device, keep_ram_copy=self._keep_ram_copy_of_weights
)
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device, keep_ram_copy=False)
else:
wrapped_model = CachedModelOnlyFullLoad(
model, self._execution_device, size, keep_ram_copy=self._keep_ram_copy_of_weights
)
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
self._cached_models[key] = cache_record
@@ -195,7 +165,6 @@ class ModelCache:
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
)
@synchronized
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
"""Retrieve a model from the cache.
@@ -231,7 +200,6 @@ class ModelCache:
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
return cache_entry
@synchronized
def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
"""Lock a model for use and move it into VRAM."""
if cache_entry.key not in self._cached_models:
@@ -267,7 +235,6 @@ class ModelCache:
self._log_cache_state()
@synchronized
def unlock(self, cache_entry: CacheRecord) -> None:
"""Unlock a model."""
if cache_entry.key not in self._cached_models:
@@ -415,77 +382,41 @@ class ModelCache:
# Alternative definition of VRAM in use:
# return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
def _calc_ram_available_to_model_cache(self) -> int:
"""Calculate the amount of RAM available for the cache to use."""
def _get_ram_available(self) -> int:
"""Get the amount of RAM available for the cache to use, while keeping memory pressure under control."""
# If self._max_ram_cache_size_gb is set, then it overrides the default logic.
if self._max_ram_cache_size_gb is not None:
self._logger.info(f"Using user-defined RAM cache size: {self._max_ram_cache_size_gb} GB.")
return int(self._max_ram_cache_size_gb * GB)
ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
return ram_total_available_to_cache - self._get_ram_in_use()
# Heuristics for dynamically calculating the RAM cache size, **in order of increasing priority**:
# 1. As an initial default, use 50% of the total RAM for InvokeAI.
# - Assume a 2GB baseline for InvokeAI's non-model RAM usage, and use the rest of the RAM for the model cache.
# 2. On a system with a lot of RAM, users probably don't want InvokeAI to eat up too much RAM.
# There are diminishing returns to storing more and more models. So, we apply an upper bound. (Keep in mind
# that most OSes have some amount of disk caching, which we still benefit from if there is excess memory,
# even if we drop models from the cache.)
# - On systems without a CUDA device, the upper bound is 32GB.
# - On systems with a CUDA device, the upper bound is 1x the amount of VRAM (less the working memory).
# 3. Absolute minimum of 4GB.
virtual_memory = psutil.virtual_memory()
ram_total = virtual_memory.total
ram_available = virtual_memory.available
ram_used = ram_total - ram_available
# NOTE(ryand): We explored dynamically adjusting the RAM cache size based on memory pressure (using psutil), but
# decided against it for now, for the following reasons:
# - It was surprisingly difficult to get memory metrics with consistent definitions across OSes. (If you go
# down this path again, don't underestimate the amount of complexity here and be sure to test rigorously on all
# OSes.)
# - Making the RAM cache size dynamic opens the door for performance regressions that are hard to diagnose and
# hard for users to understand. It is better for users to see that their RAM is maxed out, and then override
# the default value if desired.
# The total size of all the models in the cache will often be larger than the amount of RAM reported by psutil
# (due to lazy-loading and OS RAM caching behaviour). We could just rely on the psutil values, but it feels
# like a bad idea to over-fill the model cache. So, for now, we'll try to keep the total size of models in the
# cache under the total amount of system RAM.
cache_ram_used = self._get_ram_in_use()
ram_used = max(cache_ram_used, ram_used)
# Lookup the total VRAM size for the CUDA execution device.
total_cuda_vram_bytes: int | None = None
if self._execution_device.type == "cuda":
_, total_cuda_vram_bytes = torch.cuda.mem_get_info(self._execution_device)
# Aim to keep 10% of RAM free.
ram_available_based_on_memory_usage = int(ram_total * 0.9) - ram_used
# Apply heuristic 1.
# ------------------
heuristics_applied = [1]
total_system_ram_bytes = psutil.virtual_memory().total
# Assumed baseline RAM used by InvokeAI for non-model stuff.
baseline_ram_used_by_invokeai = 2 * GB
ram_available_to_model_cache = int(total_system_ram_bytes * 0.5 - baseline_ram_used_by_invokeai)
# If we are running out of RAM, then there's an increased likelihood that we will run into this issue:
# https://github.com/invoke-ai/InvokeAI/issues/7513
# To keep things running smoothly, there's a minimum RAM cache size that we always allow (even if this means
# using swap).
min_ram_cache_size_bytes = 4 * GB
ram_available_based_on_min_cache_size = min_ram_cache_size_bytes - cache_ram_used
# Apply heuristic 2.
# ------------------
max_ram_cache_size_bytes = 32 * GB
if total_cuda_vram_bytes is not None:
if self._max_vram_cache_size_gb is not None:
max_ram_cache_size_bytes = int(self._max_vram_cache_size_gb * GB)
else:
max_ram_cache_size_bytes = total_cuda_vram_bytes - int(self._execution_device_working_mem_gb * GB)
if ram_available_to_model_cache > max_ram_cache_size_bytes:
heuristics_applied.append(2)
ram_available_to_model_cache = max_ram_cache_size_bytes
# Apply heuristic 3.
# ------------------
if ram_available_to_model_cache < 4 * GB:
heuristics_applied.append(3)
ram_available_to_model_cache = 4 * GB
self._logger.info(
f"Calculated model RAM cache size: {ram_available_to_model_cache / MB:.2f} MB. Heuristics applied: {heuristics_applied}."
)
return ram_available_to_model_cache
return max(ram_available_based_on_memory_usage, ram_available_based_on_min_cache_size)
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
def _get_ram_available(self) -> int:
"""Get the amount of RAM available for the cache to use."""
return self._ram_cache_size_bytes - self._get_ram_in_use()
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
@@ -601,7 +532,6 @@ class ModelCache:
self._logger.debug(log)
@synchronized
def make_room(self, bytes_needed: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size.

View File

@@ -7,6 +7,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
CustomModuleMixin,
)
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
@@ -21,6 +22,25 @@ def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight:
return x
def concatenated_lora_forward(
input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def autocast_linear_forward_sidecar_patches(
orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> torch.Tensor:
@@ -46,6 +66,8 @@ def autocast_linear_forward_sidecar_patches(
output += linear_lora_forward(orig_input, patch, patch_weight)
elif isinstance(patch, LoRALayer):
output += linear_lora_forward(input, patch, patch_weight)
elif isinstance(patch, ConcatenatedLoRALayer):
output += concatenated_lora_forward(input, patch, patch_weight)
else:
unprocessed_patches_and_weights.append((patch, patch_weight))
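
The `concatenated_lora_forward` above computes each sub-LoRA's residual separately and concatenates along the last dimension; this is equivalent to first concatenating the per-matrix delta-weights (as a fused Q/K/V weight would be) and multiplying once. A small numeric check of that equivalence with tiny random matrices; `mid`, `bias`, and scale factors are left out for brevity.

```python
import torch

torch.manual_seed(0)
in_features, rank, chunk_out = 8, 2, 4
x = torch.randn(3, in_features)

# Three independent LoRA (down, up) pairs, one per output chunk.
loras = [(torch.randn(rank, in_features), torch.randn(chunk_out, rank)) for _ in range(3)]

# Per-chunk residuals, concatenated along the last (feature) dimension.
chunks = [torch.nn.functional.linear(torch.nn.functional.linear(x, down), up) for down, up in loras]
per_chunk = torch.cat(chunks, dim=-1)

# The same result computed by first concatenating the full delta-W matrices along dim 0.
delta_w = torch.cat([up @ down for down, up in loras], dim=0)
merged = torch.nn.functional.linear(x, delta_w)

print(torch.allclose(per_chunk, merged, atol=1e-5))   # True
```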

View File

@@ -3,8 +3,6 @@ import copy
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
class CustomModuleMixin:
@@ -44,20 +42,6 @@ class CustomModuleMixin:
device: torch.device | None = None,
):
"""Helper function that aggregates the parameters from all patches into a single dict."""
# HACK(ryand): If the original parameters are in a quantized format whose weights can't be accessed, we replace
# them with dummy tensors on the 'meta' device. This allows patch layers to access the shapes of the original
# parameters. But, of course, any sub-layers that need to access the actual values of the parameters will fail.
for param_name in orig_params.keys():
param = orig_params[param_name]
if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
pass
elif type(param) is GGMLTensor:
# Move to device and dequantize here. Doing it in the patch layer can result in redundant casts /
# dequantizations.
orig_params[param_name] = param.to(device=device).get_dequantized_tensor()
else:
orig_params[param_name] = torch.empty(get_param_shape(param), device="meta")
params: dict[str, torch.Tensor] = {}
for patch, patch_weight in patches_and_weights:

View File

@@ -80,19 +80,19 @@ class FluxVAELoader(ModelLoader):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
with accelerate.init_empty_weights():
with SilenceWarnings():
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
# VAE is broken in float16, which mps defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.float32
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
# VAE is broken in float16, which mps defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.float32
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)
return model
@@ -183,9 +183,7 @@ class T5EncoderCheckpointModel(ModelLoader):
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(
Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
)
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -219,18 +217,17 @@ class FluxCheckpointModel(ModelLoader):
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
with accelerate.init_empty_weights():
with SilenceWarnings():
model = Flux(params[config.config_path])
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
self._ram_cache.make_room(new_sd_size)
for k in sd.keys():
# We need to cast to bfloat16 due to it being the only currently supported dtype for inference
sd[k] = sd[k].to(torch.bfloat16)
model.load_state_dict(sd, assign=True)
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
self._ram_cache.make_room(new_sd_size)
for k in sd.keys():
# We need to cast to bfloat16 due to it being the only currently supported dtype for inference
sd[k] = sd[k].to(torch.bfloat16)
model.load_state_dict(sd, assign=True)
return model
@@ -261,11 +258,11 @@ class FluxGGUFCheckpointModel(ModelLoader):
assert isinstance(config, MainGGUFCheckpointConfig)
model_path = Path(config.path)
with accelerate.init_empty_weights():
with SilenceWarnings():
model = Flux(params[config.config_path])
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
# HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
# We override the shape here to fix the issue.
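
The loaders above repeat a small dtype-selection dance: never run the FLUX VAE in float16, and probe whether the device can allocate bfloat16 tensors before falling back to float32 (older MPS builds raise `TypeError` on the probe). A sketch of that logic as a standalone helper; the function name and structure are my own, only the probe itself is taken from the hunk.

```python
import torch


def pick_vae_dtype(requested: torch.dtype, device: torch.device) -> torch.dtype:
    """Return a safe dtype for the FLUX VAE: avoid float16, prefer bfloat16, else float32."""
    if requested != torch.float16:
        return requested
    try:
        # If the device cannot allocate bfloat16 tensors, this raises TypeError.
        return torch.tensor([1.0], dtype=torch.bfloat16, device=device).dtype
    except TypeError:
        return torch.float32


print(pick_vae_dtype(torch.float16, torch.device("cpu")))   # torch.bfloat16
print(pick_vae_dtype(torch.float32, torch.device("cpu")))   # torch.float32
```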

View File

@@ -31,10 +31,6 @@ from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
lora_model_from_flux_onetrainer_state_dict,
)
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
@@ -88,12 +84,8 @@ class LoRALoader(ModelLoader):
elif config.format == ModelFormat.LyCORIS:
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:

View File

@@ -46,9 +46,6 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -286,7 +283,7 @@ class ModelProbe(object):
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
elif key.startswith(("lora_te_", "lora_unet_")):
return ModelType.LoRA
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
@@ -635,7 +632,6 @@ class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
if (
is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
or is_state_dict_likely_in_flux_onetrainer_format(self.checkpoint)
or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
or is_state_dict_likely_flux_control(self.checkpoint)
):

View File

@@ -0,0 +1,55 @@
from typing import Optional, Sequence
import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
class ConcatenatedLoRALayer(LoRALayerBase):
"""A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.
This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(self, lora_layers: Sequence[LoRALayer], concat_axis: int = 0):
super().__init__(alpha=None, bias=None)
self.lora_layers = lora_layers
self.concat_axis = concat_axis
def _rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
# TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original weight tensor here.
# Note that we must apply the sub-layer scales here.
layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers] # pyright: ignore[reportArgumentType]
return torch.cat(layer_weights, dim=self.concat_axis)
def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]:
# TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
# require this value, we will need to implement chunking of the original bias tensor here.
# Note that we must apply the sub-layer scales here.
layer_biases: list[torch.Tensor] = []
for lora_layer in self.lora_layers:
layer_bias = lora_layer.get_bias(None)
if layer_bias is not None:
layer_biases.append(layer_bias * lora_layer.scale())
if len(layer_biases) == 0:
return None
assert len(layer_biases) == len(self.lora_layers)
return torch.cat(layer_biases, dim=self.concat_axis)
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
for lora_layer in self.lora_layers:
lora_layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + sum(lora_layer.calc_size() for lora_layer in self.lora_layers)
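
For orientation, a sketch of what `ConcatenatedLoRALayer.get_weight()` above produces when three separate Q/K/V LoRA deltas are fused. The sub-layer here is a minimal stand-in exposing only `get_weight()`/`scale()`, not the real `LoRALayer` class, so treat it as an illustration of intent rather than exact API.

```python
import torch


class _StubLoRA:
    """Stand-in for LoRALayer: just enough surface (get_weight/scale) for the demo."""

    def __init__(self, up: torch.Tensor, down: torch.Tensor):
        self.up, self.down = up, down

    def get_weight(self, orig_weight) -> torch.Tensor:
        return self.up @ self.down

    def scale(self) -> float:
        return 1.0


hidden = 4
q, k, v = (_StubLoRA(torch.randn(hidden, 2), torch.randn(2, hidden)) for _ in range(3))

# What ConcatenatedLoRALayer.get_weight() does: stack the per-matrix deltas along
# dim 0 so they line up with a fused qkv weight of shape (3 * hidden, hidden).
delta_qkv = torch.cat([layer.get_weight(None) * layer.scale() for layer in (q, k, v)], dim=0)
print(delta_qkv.shape)   # torch.Size([12, 4])
```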

View File

@@ -1,115 +0,0 @@
from typing import Dict, Optional
import torch
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
class DoRALayer(LoRALayerBase):
"""A DoRA layer. As defined in https://arxiv.org/pdf/2402.09353."""
def __init__(
self,
up: torch.Tensor,
down: torch.Tensor,
dora_scale: torch.Tensor,
alpha: float | None,
bias: Optional[torch.Tensor],
):
super().__init__(alpha, bias)
self.up = up
self.down = down
self.dora_scale = dora_scale
@classmethod
def from_state_dict_values(cls, values: Dict[str, torch.Tensor]):
alpha = cls._parse_alpha(values.get("alpha", None))
bias = cls._parse_bias(
values.get("bias_indices", None), values.get("bias_values", None), values.get("bias_size", None)
)
layer = cls(
up=values["lora_up.weight"],
down=values["lora_down.weight"],
dora_scale=values["dora_scale"],
alpha=alpha,
bias=bias,
)
cls.warn_on_unhandled_keys(
values=values,
handled_keys={
# Default keys.
"alpha",
"bias_indices",
"bias_values",
"bias_size",
# Layer-specific keys.
"lora_up.weight",
"lora_down.weight",
"dora_scale",
},
)
return layer
def _rank(self) -> int:
return self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
orig_weight = cast_to_device(orig_weight, self.up.device)
# Note: Variable names (e.g. delta_v) are based on the paper.
delta_v = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
delta_v = delta_v.reshape(orig_weight.shape)
delta_v = delta_v * self.scale()
# At this point, out_weight is the unnormalized direction matrix.
out_weight = orig_weight + delta_v
# TODO(ryand): Simplify this logic.
direction_norm = (
out_weight.transpose(0, 1)
.reshape(out_weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(out_weight.shape[1], *[1] * (out_weight.dim() - 1))
.transpose(0, 1)
)
out_weight *= self.dora_scale / direction_norm
return out_weight - orig_weight
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
self.dora_scale = self.dora_scale.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensors_size([self.up, self.down, self.dora_scale])
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
if any(p.device.type == "meta" for p in orig_parameters.values()):
# If any of the original parameters are on the 'meta' device, we assume this is because the base model is in
# a quantization format that doesn't allow easy dequantization.
raise RuntimeError(
"The base model quantization format (likely bitsandbytes) is not compatible with DoRA patches."
)
scale = self.scale()
params = {"weight": self.get_weight(orig_parameters["weight"]) * weight}
bias = self.get_bias(orig_parameters.get("bias", None))
if bias is not None:
params["bias"] = bias * (weight * scale)
# Reshape all params to match the original module's shape.
for param_name, param_weight in params.items():
orig_param = orig_parameters[param_name]
if param_weight.shape != orig_param.shape:
params[param_name] = param_weight.reshape(orig_param.shape)
return params
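
As a companion to `DoRALayer.get_weight()` above, the same computation restated as a compact numerical sketch for the 2-D Linear case (tiny random tensors; the layer `scale()` factor is folded in as 1.0, and the simple `norm(dim=0)` here is equivalent to the transpose/reshape dance in the code when the weight is 2-D).

```python
import torch

torch.manual_seed(0)
out_f, in_f, rank = 6, 4, 2
orig_weight = torch.randn(out_f, in_f)
up, down = torch.randn(out_f, rank), torch.randn(rank, in_f)
dora_scale = torch.randn(1, in_f).abs()             # learned per-column magnitude

# DoRA weight update, mirroring get_weight() above:
delta_v = up @ down                                 # low-rank direction update
unnormalized = orig_weight + delta_v                # unnormalized direction matrix
col_norm = unnormalized.norm(dim=0, keepdim=True)   # column-wise norm, shape (1, in_f)
new_weight = unnormalized * (dora_scale / col_norm)
residual = new_weight - orig_weight                 # what get_weight() returns

print(residual.shape)   # torch.Size([6, 4])
```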

View File

@@ -4,7 +4,6 @@ import torch
import invokeai.backend.util.logging as logger
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
@@ -68,8 +67,8 @@ class LoRALayerBase(BaseLayerPatch):
# Reshape all params to match the original module's shape.
for param_name, param_weight in params.items():
orig_param = orig_parameters[param_name]
if param_weight.shape != get_param_shape(orig_param):
params[param_name] = param_weight.reshape(get_param_shape(orig_param))
if param_weight.shape != orig_param.shape:
params[param_name] = param_weight.reshape(orig_param.shape)
return params

View File

@@ -1,65 +0,0 @@
from dataclasses import dataclass
from typing import Sequence
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
@dataclass
class Range:
start: int
end: int
class MergedLayerPatch(BaseLayerPatch):
"""A patch layer that is composed of multiple sub-layers merged together.
This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(
self,
lora_layers: Sequence[BaseLayerPatch],
ranges: Sequence[Range],
):
super().__init__()
self.lora_layers = lora_layers
# self.ranges[i] is the range for the i'th lora layer along the 0'th weight dimension.
self.ranges = ranges
assert len(self.ranges) == len(self.lora_layers)
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
out_parameters: dict[str, torch.Tensor] = {}
for lora_layer, range in zip(self.lora_layers, self.ranges, strict=True):
sliced_parameters: dict[str, torch.Tensor] = {
n: p[range.start : range.end] for n, p in orig_parameters.items()
}
# Note that `weight` is applied in the sub-layers, no need to apply it in this function.
layer_out_parameters = lora_layer.get_parameters(sliced_parameters, weight)
for out_param_name, out_param in layer_out_parameters.items():
if out_param_name not in out_parameters:
# If not already in the output dict, initialize an output tensor with the same shape as the full
# original parameter.
out_parameters[out_param_name] = torch.zeros(
get_param_shape(orig_parameters[out_param_name]),
dtype=out_param.dtype,
device=out_param.device,
)
out_parameters[out_param_name][range.start : range.end] += out_param
return out_parameters
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
for lora_layer in self.lora_layers:
lora_layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return sum(lora_layer.calc_size() for lora_layer in self.lora_layers)
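
A hypothetical illustration of how the `Range`s in `MergedLayerPatch` above carve up a fused parameter: each sub-layer sees only its slice, and its output is accumulated back into the matching rows. The sub-layer here is a trivial stand-in, not a real `BaseLayerPatch` implementation.

```python
from dataclasses import dataclass

import torch


@dataclass
class Range:
    start: int
    end: int


class _StubPatch:
    """Stand-in sub-layer: returns a constant delta for its slice of the weight."""

    def __init__(self, value: float):
        self.value = value

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        return {"weight": torch.full_like(orig_parameters["weight"], self.value) * weight}


# A fused qkv-style weight of shape (12, 4): rows 0-3 -> Q, 4-7 -> K, 8-11 -> V.
fused_weight = torch.zeros(12, 4)
ranges = [Range(0, 4), Range(4, 8), Range(8, 12)]
patches = [_StubPatch(1.0), _StubPatch(2.0), _StubPatch(3.0)]

out = torch.zeros_like(fused_weight)
for patch, r in zip(patches, ranges, strict=True):
    sliced = {"weight": fused_weight[r.start : r.end]}
    out[r.start : r.end] += patch.get_parameters(sliced, weight=1.0)["weight"]

print(out[:, 0])   # tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
```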

View File

@@ -1,19 +0,0 @@
import torch
try:
from bitsandbytes.nn.modules import Params4bit
bnb_available: bool = True
except ImportError:
bnb_available: bool = False
def get_param_shape(param: torch.Tensor) -> torch.Size:
"""A helper function to get the shape of a parameter that handles `bitsandbytes.nn.Params4Bit` correctly."""
# Accessing the `.shape` attribute of `bitsandbytes.nn.Params4Bit` will return an incorrect result. Instead, we must
# access the `.quant_state.shape` attribute.
if bnb_available and type(param) is Params4bit: # type: ignore
quant_state = param.quant_state
if quant_state is not None:
return quant_state.shape
return param.shape
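A minimal usage sketch, assuming an InvokeAI install that includes the `param_shape_utils` module shown above; the `Params4bit` branch is only exercised when bitsandbytes is present.

```python
import torch

from invokeai.backend.patches.layers.param_shape_utils import get_param_shape

# For an ordinary parameter the helper simply returns `.shape`.
linear = torch.nn.Linear(8, 4)
print(get_param_shape(linear.weight))  # torch.Size([4, 8])

# For a bitsandbytes Params4bit (not constructed here), `.shape` would report
# the packed 1-D storage, so the helper reads `quant_state.shape` instead.
```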

View File

@@ -3,7 +3,6 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.dora_layer import DoRALayer
from invokeai.backend.patches.layers.full_layer import FullLayer
from invokeai.backend.patches.layers.ia3_layer import IA3Layer
from invokeai.backend.patches.layers.loha_layer import LoHALayer
@@ -15,9 +14,8 @@ from invokeai.backend.patches.layers.norm_layer import NormLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
if "dora_scale" in state_dict:
return DoRALayer.from_state_dict_values(state_dict)
elif "lora_up.weight" in state_dict:
if "lora_up.weight" in state_dict:
# LoRA a.k.a LoCon
return LoRALayer.from_state_dict_values(state_dict)
elif "hada_w1_a" in state_dict:

View File

@@ -3,8 +3,8 @@ from typing import Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -33,21 +33,13 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
def lora_model_from_flux_diffusers_state_dict(
state_dict: Dict[str, torch.Tensor], alpha: float | None
) -> ModelPatchRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
layers = lora_layers_from_flux_diffusers_grouped_state_dict(grouped_state_dict, alpha)
return ModelPatchRaw(layers=layers)
def lora_layers_from_flux_diffusers_grouped_state_dict(
grouped_state_dict: Dict[str, Dict[str, torch.Tensor]], alpha: float | None
) -> dict[str, BaseLayerPatch]:
"""Converts a grouped state dict with Diffusers FLUX LoRA keys to LoRA layers with BFL keys (i.e. the module key
format used by Invoke).
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
"""
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
# Remove the "transformer." prefix from all keys.
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
@@ -61,26 +53,17 @@ def lora_layers_from_flux_diffusers_grouped_state_dict(
layers: dict[str, BaseLayerPatch] = {}
def get_lora_layer_values(src_layer_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
if "lora_A.weight" in src_layer_dict:
# The LoRA keys are in PEFT format.
values = {
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
value = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
value["alpha"] = torch.tensor(alpha)
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
assert len(src_layer_dict) == 0
return values
else:
# Assume that the LoRA keys are in Kohya format.
return src_layer_dict
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
values = get_lora_layer_values(src_layer_dict)
layers[dst_key] = any_lora_layer_from_state_dict(values)
def add_qkv_lora_layer_if_present(
src_keys: list[str],
@@ -96,24 +79,29 @@ def lora_layers_from_flux_diffusers_grouped_state_dict(
if not any(keys_present):
return
dim_0_offset = 0
sub_layers: list[BaseLayerPatch] = []
sub_layer_ranges: list[Range] = []
sub_layers: list[LoRALayer] = []
for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
src_layer_dict = grouped_state_dict.pop(src_key, None)
if src_layer_dict is not None:
values = get_lora_layer_values(src_layer_dict)
# assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
# assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
sub_layers.append(any_lora_layer_from_state_dict(values))
sub_layer_ranges.append(Range(dim_0_offset, dim_0_offset + src_weight_shape[0]))
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
else:
if not allow_missing_keys:
raise ValueError(f"Missing LoRA layer: '{src_key}'.")
dim_0_offset += src_weight_shape[0]
layers[dst_qkv_key] = MergedLayerPatch(sub_layers, sub_layer_ranges)
values = {
"lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
"lora_down.weight": torch.zeros((1, src_weight_shape[1])),
}
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
@@ -229,7 +217,7 @@ def lora_layers_from_flux_diffusers_grouped_state_dict(
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
return layers_with_prefix
return ModelPatchRaw(layers=layers_with_prefix)
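A plain-torch sketch (toy dimensions, not the real FLUX sizes) of why the separate to_q/to_k/to_v LoRA deltas line up with row offsets of the fused BFL qkv weight, which is what the QKV handling above accounts for.

```python
import torch

# Toy dimensions: three separate rank-4 LoRAs for to_q / to_k / to_v, whose
# weight deltas are stacked along dim 0 to match the fused qkv layout.
rank, dim = 4, 64
deltas = []
for _ in ("to_q", "to_k", "to_v"):
    lora_down = torch.randn(rank, dim)   # diffusers "lora_A.weight"
    lora_up = torch.randn(dim, rank)     # diffusers "lora_B.weight"
    deltas.append(lora_up @ lora_down)   # per-projection weight delta

fused_delta = torch.cat(deltas, dim=0)   # rows 0:64 -> Q, 64:128 -> K, 128:192 -> V
assert fused_delta.shape == (3 * dim, dim)
```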
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:

View File

@@ -7,7 +7,6 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_T5_PREFIX,
FLUX_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -27,14 +26,6 @@ FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
# A regex pattern that matches all of the T5 keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.alpha
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.dora_scale
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_down.weight
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_up.weight
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
@@ -43,9 +34,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
for k in state_dict.keys()
)
@@ -59,34 +48,27 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Split the grouped state dict into transformer, CLIP, and T5 state dicts.
# Split the grouped state dict into transformer and CLIP state dicts.
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_unet"):
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te2"):
t5_grouped_sd[layer_name] = layer_state_dict
else:
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
# Convert the state dicts to the InvokeAI format.
transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
# Create LoRA layers.
layers: dict[str, BaseLayerPatch] = {}
for model_prefix, grouped_sd in [
(FLUX_LORA_TRANSFORMER_PREFIX, transformer_grouped_sd),
(FLUX_LORA_CLIP_PREFIX, clip_grouped_sd),
(FLUX_LORA_T5_PREFIX, t5_grouped_sd),
]:
for layer_key, layer_state_dict in grouped_sd.items():
layers[model_prefix + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for layer_key, layer_state_dict in transformer_grouped_sd.items():
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for layer_key, layer_state_dict in clip_grouped_sd.items():
layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
return ModelPatchRaw(layers=layers)
@@ -141,31 +123,3 @@ def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict
def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a T5 LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by
InvokeAI.
Example key conversions:
"lora_te2_encoder_block_0_layer_0_SelfAttention_k" -> "encoder.block.0.layer.0.SelfAttention.k"
"lora_te2_encoder_block_0_layer_1_DenseReluDense_wi_0" -> "encoder.block.0.layer.1.DenseReluDense.wi.0"
"""
def replace_func(match: re.Match[str]) -> str:
s = f"encoder.block.{match.group(1)}.layer.{match.group(2)}.{match.group(3)}.{match.group(4)}"
if match.group(5):
s += f".{match.group(5)}"
return s
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_T5_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict
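A small standalone sketch of the T5 key rewrite above, using the regex and replacement logic copied from this file on one of the documented example keys.

```python
import re

FLUX_KOHYA_T5_KEY_REGEX = (
    r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
)

def replace_func(match: re.Match[str]) -> str:
    s = f"encoder.block.{match.group(1)}.layer.{match.group(2)}.{match.group(3)}.{match.group(4)}"
    if match.group(5):
        s += f".{match.group(5)}"
    return s

key = "lora_te2_encoder_block_0_layer_0_SelfAttention_k"
print(re.sub(FLUX_KOHYA_T5_KEY_REGEX, replace_func, key))
# encoder.block.0.layer.0.SelfAttention.k
```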

View File

@@ -1,4 +1,3 @@
# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
FLUX_LORA_CLIP_PREFIX = "lora_clip-"
FLUX_LORA_T5_PREFIX = "lora_t5-"

View File

@@ -1,163 +0,0 @@
import re
from typing import Any, Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
lora_layers_from_flux_diffusers_grouped_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
FLUX_KOHYA_CLIP_KEY_REGEX,
FLUX_KOHYA_T5_KEY_REGEX,
_convert_flux_clip_kohya_state_dict_to_invoke_format,
_convert_flux_t5_kohya_state_dict_to_invoke_format,
)
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_T5_PREFIX,
)
from invokeai.backend.patches.lora_conversions.kohya_key_utils import (
INDEX_PLACEHOLDER,
ParsingTree,
insert_periods_into_kohya_key,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
# A regex pattern that matches all of the transformer keys in the OneTrainer FLUX LoRA format.
# The OneTrainer format uses a mix of the Kohya and Diffusers formats:
# - The base model keys are in Diffusers format.
# - Periods are replaced with underscores, to match Kohya.
# - The LoRA key suffixes (e.g. .alpha, .lora_down.weight, .lora_up.weight) match Kohya.
# Example keys:
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.alpha"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.dora_scale"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_down.weight"
# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_up.weight"
FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = (
r"lora_transformer_(single_transformer_blocks|transformer_blocks)_(\d+)_(\w+)\.(.*)"
)
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
Note that OneTrainer matches the Kohya format for the CLIP and T5 models.
"""
return all(
re.match(FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX, k)
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
)
def lora_model_from_flux_onetrainer_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw: # type: ignore
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
layer_name, param_name = key.split(".", 1)
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Split the grouped state dict into transformer, CLIP, and T5 state dicts.
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_transformer"):
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te2"):
t5_grouped_sd[layer_name] = layer_state_dict
else:
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
# Convert the state dicts to the InvokeAI format.
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
# Create LoRA layers.
layers: dict[str, BaseLayerPatch] = {}
for model_prefix, grouped_sd in [
# (FLUX_LORA_TRANSFORMER_PREFIX, transformer_grouped_sd),
(FLUX_LORA_CLIP_PREFIX, clip_grouped_sd),
(FLUX_LORA_T5_PREFIX, t5_grouped_sd),
]:
for layer_key, layer_state_dict in grouped_sd.items():
layers[model_prefix + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Handle the transformer.
transformer_layers = _convert_flux_transformer_onetrainer_state_dict_to_invoke_format(transformer_grouped_sd)
layers.update(transformer_layers)
# Create and return the LoRAModelRaw.
return ModelPatchRaw(layers=layers)
# This parsing tree was generated by calling `generate_kohya_parsing_tree_from_keys()` on the keys in
# flux_lora_diffusers_format.py.
flux_transformer_kohya_parsing_tree: ParsingTree = {
"transformer": {
"single_transformer_blocks": {
INDEX_PLACEHOLDER: {
"attn": {"to_k": {}, "to_q": {}, "to_v": {}},
"norm": {"linear": {}},
"proj_mlp": {},
"proj_out": {},
}
},
"transformer_blocks": {
INDEX_PLACEHOLDER: {
"attn": {
"add_k_proj": {},
"add_q_proj": {},
"add_v_proj": {},
"to_add_out": {},
"to_k": {},
"to_out": {INDEX_PLACEHOLDER: {}},
"to_q": {},
"to_v": {},
},
"ff": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
"ff_context": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
"norm1": {"linear": {}},
"norm1_context": {"linear": {}},
}
},
}
}
def _convert_flux_transformer_onetrainer_state_dict_to_invoke_format(
state_dict: Dict[str, Dict[str, torch.Tensor]],
) -> dict[str, BaseLayerPatch]:
"""Converts a FLUX transformer LoRA state dict from the OneTrainer FLUX LoRA format to the LoRA weight format used
internally by InvokeAI.
"""
# Step 1: Convert the Kohya-style keys with underscores to classic keys with periods.
# Example:
# "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_down.weight" -> "transformer.single_transformer_blocks.0.attn.to_k.lora_down.weight"
lora_prefix = "lora_"
lora_prefix_length = len(lora_prefix)
kohya_state_dict: dict[str, Dict[str, torch.Tensor]] = {}
for key in state_dict.keys():
# Remove the "lora_" prefix.
assert key.startswith(lora_prefix)
new_key = key[lora_prefix_length:]
# Add periods to the Kohya-style module keys.
new_key = insert_periods_into_kohya_key(new_key, flux_transformer_kohya_parsing_tree)
# Replace the old key with the new key.
kohya_state_dict[new_key] = state_dict[key]
# Step 2: Convert diffusers module names to the BFL module names.
return lora_layers_from_flux_diffusers_grouped_state_dict(kohya_state_dict, alpha=None)
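A quick standalone check (regex copied from the top of this file) that the documented OneTrainer example keys match the transformer detection pattern.

```python
import re

FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = (
    r"lora_transformer_(single_transformer_blocks|transformer_blocks)_(\d+)_(\w+)\.(.*)"
)

# Example keys taken from the comments above; all should match.
keys = [
    "lora_transformer_single_transformer_blocks_0_attn_to_k.alpha",
    "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_down.weight",
    "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_up.weight",
]
print(all(re.match(FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX, k) for k in keys))  # True
```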

View File

@@ -1,102 +0,0 @@
from typing import Iterable
INDEX_PLACEHOLDER = "index_placeholder"
# Type alias for a 'ParsingTree', which is a recursive dict with string keys.
ParsingTree = dict[str, "ParsingTree"]
def insert_periods_into_kohya_key(key: str, parsing_tree: ParsingTree) -> str:
"""Insert periods into a Kohya key based on a parsing tree.
Kohya format keys are produced by replacing periods with underscores in the original key.
Example:
```
key = "module_a_module_b_0_attn_to_k"
parsing_tree = {
"module_a": {
"module_b": {
INDEX_PLACEHOLDER: {
"attn": {},
},
},
},
}
result = insert_periods_into_kohya_key(key, parsing_tree)
> "module_a.module_b.0.attn.to_k"
```
"""
# Split key into parts by underscore.
parts = key.split("_")
# Build up result by walking through parsing tree and parts.
result_parts: list[str] = []
current_part = ""
current_tree = parsing_tree
for part in parts:
if len(current_part) > 0:
current_part = current_part + "_"
current_part += part
if current_part in current_tree:
# Match found.
current_tree = current_tree[current_part]
result_parts.append(current_part)
current_part = ""
elif current_part.isnumeric() and INDEX_PLACEHOLDER in current_tree:
# Match found with index placeholder.
current_tree = current_tree[INDEX_PLACEHOLDER]
result_parts.append(current_part)
current_part = ""
if len(current_part) > 0:
raise ValueError(f"Key {key} does not match parsing tree {parsing_tree}.")
return ".".join(result_parts)
def generate_kohya_parsing_tree_from_keys(keys: Iterable[str]) -> ParsingTree:
"""Generate a parsing tree from a list of keys.
Example:
```
keys = [
"module_a.module_b.0.attn.to_k",
"module_a.module_b.1.attn.to_k",
"module_a.module_c.proj",
]
tree = generate_kohya_parsing_tree_from_keys(keys)
> {
> "module_a": {
> "module_b": {
> INDEX_PLACEHOLDER: {
> "attn": {
> "to_k": {},
> "to_q": {},
> },
> }
> },
> "module_c": {
> "proj": {},
> }
> }
> }
```
"""
tree: ParsingTree = {}
for key in keys:
subtree: ParsingTree = tree
for module_name in key.split("."):
key = module_name
if module_name.isnumeric():
key = INDEX_PLACEHOLDER
if key not in subtree:
subtree[key] = {}
subtree = subtree[key]
return tree
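A round-trip sketch, assuming an InvokeAI revision that still ships `kohya_key_utils`: build a parsing tree from a dotted key, then recover that key from its underscore-flattened (Kohya-style) form.

```python
from invokeai.backend.patches.lora_conversions.kohya_key_utils import (
    generate_kohya_parsing_tree_from_keys,
    insert_periods_into_kohya_key,
)

# The tree generated from the dotted key lets the flattened key be re-split
# unambiguously, even though module names themselves contain underscores.
tree = generate_kohya_parsing_tree_from_keys(["module_a.module_b.0.attn.to_k"])
print(insert_periods_into_kohya_key("module_a_module_b_0_attn_to_k", tree))
# module_a.module_b.0.attn.to_k
```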

View File

@@ -54,9 +54,7 @@ GGML_TENSOR_OP_TABLE = {
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore
torch.ops.aten.sub.Tensor: dequantize_and_run, # pyright: ignore
torch.ops.aten.allclose.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.slice.Tensor: dequantize_and_run, # pyright: ignore
}
if torch.backends.mps.is_available():

View File

@@ -76,7 +76,6 @@
"konva": "^9.3.15",
"lodash-es": "^4.17.21",
"lru-cache": "^11.0.1",
"mtwist": "^1.0.2",
"nanoid": "^5.0.7",
"nanostores": "^0.11.3",
"new-github-issue-url": "^1.0.0",

View File

@@ -77,9 +77,6 @@ dependencies:
lru-cache:
specifier: ^11.0.1
version: 11.0.1
mtwist:
specifier: ^1.0.2
version: 1.0.2
nanoid:
specifier: ^5.0.7
version: 5.0.7
@@ -7019,10 +7016,6 @@ packages:
/ms@2.1.3:
resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==}
/mtwist@1.0.2:
resolution: {integrity: sha512-eRsSga5jkLg7nNERPOV8vDNxgSwuEcj5upQfJcT0gXfJwXo3pMc7xOga0fu8rXHyrxzl7GFVWWDuaPQgpKDvgw==}
dev: false
/muggle-string@0.3.1:
resolution: {integrity: sha512-ckmWDJjphvd/FvZawgygcUeQCxzvohjFO5RxTjj4eq8kw359gFF3E1brjfI+viLMxss5JrHTDRHZvu2/tuy0Qg==}
dev: true

View File

@@ -99,15 +99,7 @@
"clipboard": "Zwischenablage",
"generating": "Generieren",
"loadingModel": "Lade Modell",
"warnings": "Warnungen",
"start": "Starten",
"count": "Anzahl",
"step": "Schritt",
"values": "Werte",
"min": "Min",
"max": "Max",
"resetToDefaults": "Auf Standard zurücksetzen",
"seed": "Seed"
"warnings": "Warnungen"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -127,6 +119,7 @@
"autoAssignBoardOnClick": "Board per Klick automatisch zuweisen",
"noImageSelected": "Kein Bild ausgewählt",
"starImage": "Bild markieren",
"assets": "Ressourcen",
"unstarImage": "Markierung entfernen",
"image": "Bild",
"deleteSelection": "Lösche Auswahl",
@@ -1289,16 +1282,7 @@
"unknownFieldType": "$t(nodes.unknownField) Typ: {{type}}",
"unknownField": "Unbekanntes Feld",
"unableToUpdateNodes_one": "{{count}} Knoten kann nicht aktualisiert werden",
"unableToUpdateNodes_other": "{{count}} Knoten können nicht aktualisiert werden",
"uniformRandomDistribution": "Uniforme Zufallsverteilung",
"linearDistribution": "Lineare Verteilung",
"generatorNRandomValues_one": "{{count}} Zufallswert",
"generatorNRandomValues_other": "{{count}} Zufallswerte",
"arithmeticSequence": "Arithmetische Folge",
"noBatchGroup": "keine Gruppe",
"generatorNoValues": "leer",
"generatorLoading": "wird geladen",
"generatorLoadFromFile": "Aus Datei laden"
"unableToUpdateNodes_other": "{{count}} Knoten können nicht aktualisiert werden"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -177,17 +177,7 @@
"none": "None",
"new": "New",
"generating": "Generating",
"warnings": "Warnings",
"start": "Start",
"count": "Count",
"step": "Step",
"end": "End",
"min": "Min",
"max": "Max",
"values": "Values",
"resetToDefaults": "Reset to Defaults",
"seed": "Seed",
"combinatorial": "Combinatorial"
"warnings": "Warnings"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -219,13 +209,9 @@
"pauseSucceeded": "Processor Paused",
"pauseFailed": "Problem Pausing Processor",
"cancel": "Cancel",
"cancelAllExceptCurrentQueueItemAlertDialog": "Canceling all queue items except the current one will stop pending items but allow the in-progress one to finish.",
"cancelAllExceptCurrentQueueItemAlertDialog2": "Are you sure you want to cancel all pending queue items?",
"cancelAllExceptCurrentTooltip": "Cancel All Except Current Item",
"cancelTooltip": "Cancel Current Item",
"cancelSucceeded": "Item Canceled",
"cancelFailed": "Problem Canceling Item",
"confirm": "Confirm",
"prune": "Prune",
"pruneTooltip": "Prune {{item_count}} Completed Items",
"pruneSucceeded": "Pruned {{item_count}} Completed Items from Queue",
@@ -298,14 +284,10 @@
"disableFailed": "Problem Disabling Invocation Cache",
"useCache": "Use Cache"
},
"modelCache": {
"clear": "Clear Model Cache",
"clearSucceeded": "Model Cache Cleared",
"clearFailed": "Problem Clearing Model Cache"
},
"gallery": {
"gallery": "Gallery",
"alwaysShowImageSizeBadge": "Always Show Image Size Badge",
"assets": "Assets",
"assetsTab": "Files youve uploaded for use in your projects.",
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
"autoSwitchNewImages": "Auto-Switch to New Images",
@@ -868,19 +850,6 @@
"defaultVAE": "Default VAE"
},
"nodes": {
"arithmeticSequence": "Arithmetic Sequence",
"linearDistribution": "Linear Distribution",
"uniformRandomDistribution": "Uniform Random Distribution",
"parseString": "Parse String",
"splitOn": "Split On",
"noBatchGroup": "no group",
"generatorNRandomValues_one": "{{count}} random value",
"generatorNRandomValues_other": "{{count}} random values",
"generatorNoValues": "empty",
"generatorLoading": "loading",
"generatorLoadFromFile": "Load from File",
"dynamicPromptsRandom": "Dynamic Prompts (Random)",
"dynamicPromptsCombinatorial": "Dynamic Prompts (Combinatorial)",
"addNode": "Add Node",
"addNodeToolTip": "Add Node (Shift+A, Space)",
"addLinearView": "Add to Linear View",
@@ -1020,11 +989,7 @@
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default",
"saveToGallery": "Save To Gallery",
"addItem": "Add Item",
"generateValues": "Generate Values",
"floatRangeGenerator": "Float Range Generator",
"integerRangeGenerator": "Integer Range Generator"
"saveToGallery": "Save To Gallery"
},
"parameters": {
"aspect": "Aspect",
@@ -1059,22 +1024,11 @@
"addingImagesTo": "Adding images to",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
"missingInputForField": "missing input",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: missing input",
"missingNodeTemplate": "Missing node template",
"emptyBatches": "empty batches",
"batchNodeNotConnected": "Batch node not connected: {{label}}",
"batchNodeEmptyCollection": "Some batch nodes have empty collections",
"invalidBatchConfigurationCannotCalculate": "Invalid batch configuration; cannot calculate",
"collectionTooFewItems": "too few items, minimum {{minItems}}",
"collectionTooManyItems": "too many items, maximum {{maxItems}}",
"collectionStringTooLong": "too long, max {{maxLength}}",
"collectionStringTooShort": "too short, min {{minLength}}",
"collectionNumberGTMax": "{{value}} > {{maximum}} (inc max)",
"collectionNumberLTMin": "{{value}} < {{minimum}} (inc min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)",
"collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}",
"batchNodeCollectionSizeMismatch": "Collection size mismatch on Batch {{batchGroupId}}",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} empty collection",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: too few items, minimum {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: too many items, maximum {{maxItems}}",
"noModelSelected": "No model selected",
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
@@ -1146,8 +1100,7 @@
"perPromptLabel": "Seed per Image",
"perPromptDesc": "Use a different seed for each image"
},
"loading": "Generating Dynamic Prompts...",
"promptsToGenerate": "Prompts to Generate"
"loading": "Generating Dynamic Prompts..."
},
"sdxl": {
"cfgScale": "CFG Scale",
@@ -1247,8 +1200,6 @@
"problemCopyingLayer": "Unable to Copy Layer",
"problemSavingLayer": "Unable to Save Layer",
"problemDownloadingImage": "Unable to Download Image",
"pasteSuccess": "Pasted to {{destination}}",
"pasteFailed": "Paste Failed",
"prunedQueue": "Pruned Queue",
"sentToCanvas": "Sent to Canvas",
"sentToUpscale": "Sent to Upscale",
@@ -1705,8 +1656,6 @@
"cropLayerToBbox": "Crop Layer to Bbox",
"savedToGalleryOk": "Saved to Gallery",
"savedToGalleryError": "Error saving to gallery",
"regionCopiedToClipboard": "{{region}} Copied to Clipboard",
"copyRegionError": "Error copying {{region}}",
"newGlobalReferenceImageOk": "Created Global Reference Image",
"newGlobalReferenceImageError": "Problem Creating Global Reference Image",
"newRegionalReferenceImageOk": "Created Regional Reference Image",
@@ -1817,14 +1766,6 @@
"newControlLayer": "New $t(controlLayers.controlLayer)",
"newInpaintMask": "New $t(controlLayers.inpaintMask)",
"newRegionalGuidance": "New $t(controlLayers.regionalGuidance)",
"pasteTo": "Paste To",
"pasteToAssets": "Assets",
"pasteToAssetsDesc": "Paste to Assets",
"pasteToBbox": "Bbox",
"pasteToBboxDesc": "New Layer (in Bbox)",
"pasteToCanvas": "Canvas",
"pasteToCanvasDesc": "New Layer (in Canvas)",
"pastedTo": "Pasted to {{destination}}",
"transparency": "Transparency",
"enableTransparencyEffect": "Enable Transparency Effect",
"disableTransparencyEffect": "Disable Transparency Effect",
@@ -1991,48 +1932,6 @@
"description": "Generates an edge map from the selected layer using the PiDiNet edge detection model.",
"scribble": "Scribble",
"quantize_edges": "Quantize Edges"
},
"img_blur": {
"label": "Blur Image",
"description": "Blurs the selected layer.",
"blur_type": "Blur Type",
"blur_radius": "Radius",
"gaussian_type": "Gaussian",
"box_type": "Box"
},
"img_noise": {
"label": "Noise Image",
"description": "Adds noise to the selected layer.",
"noise_type": "Noise Type",
"noise_amount": "Amount",
"gaussian_type": "Gaussian",
"salt_and_pepper_type": "Salt and Pepper",
"noise_color": "Colored Noise",
"size": "Noise Size"
},
"adjust_image": {
"label": "Adjust Image",
"description": "Adjusts the selected channel of an image.",
"channel": "Channel",
"value_setting": "Value",
"scale_values": "Scale Values",
"red": "Red (RGBA)",
"green": "Green (RGBA)",
"blue": "Blue (RGBA)",
"alpha": "Alpha (RGBA)",
"cyan": "Cyan (CMYK)",
"magenta": "Magenta (CMYK)",
"yellow": "Yellow (CMYK)",
"black": "Black (CMYK)",
"hue": "Hue (HSV)",
"saturation": "Saturation (HSV)",
"value": "Value (HSV)",
"luminosity": "Luminosity (LAB)",
"a": "A (LAB)",
"b": "B (LAB)",
"y": "Y (YCbCr)",
"cb": "Cb (YCbCr)",
"cr": "Cr (YCbCr)"
}
},
"transform": {
@@ -2106,10 +2005,7 @@
"newRasterLayer": "New Raster Layer",
"newInpaintMask": "New Inpaint Mask",
"newRegionalGuidance": "New Regional Guidance",
"cropCanvasToBbox": "Crop Canvas to Bbox",
"copyToClipboard": "Copy to Clipboard",
"copyCanvasToClipboard": "Copy Canvas to Clipboard",
"copyBboxToClipboard": "Copy Bbox to Clipboard"
"cropCanvasToBbox": "Crop Canvas to Bbox"
},
"stagingArea": {
"accept": "Accept",
@@ -2243,14 +2139,7 @@
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Improved VRAM setting defaults",
"On-demand model cache clearing",
"Expanded FLUX LoRA compatibility",
"Canvas Adjust Image filter",
"Cancel all but current queue item",
"Copy from and paste to Canvas"
],
"items": ["Low-VRAM mode", "Dynamic memory management", "Faster model loading times", "Fewer memory errors"],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
"watchUiUpdatesOverview": "Watch UI Updates Overview"

View File

@@ -109,6 +109,7 @@
"deleteImage_many": "Eliminar {{count}} Imágenes",
"deleteImage_other": "Eliminar {{count}} Imágenes",
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
"assets": "Activos",
"autoAssignBoardOnClick": "Asignar automática tableros al hacer clic",
"gallery": "Galería",
"noImageSelected": "Sin imágenes seleccionadas",
@@ -900,7 +901,9 @@
}
},
"newUserExperience": {
"downloadStarterModels": "Descargar modelos de inicio",
"toGetStarted": "Para empezar, introduzca un mensaje en el cuadro y haga clic en <StrongComponent>Invocar</StrongComponent> para generar su primera imagen. Seleccione una plantilla para mejorar los resultados. Puede elegir guardar sus imágenes directamente en <StrongComponent>Galería</StrongComponent> o editarlas en <StrongComponent>Lienzo</StrongComponent>.",
"importModels": "Importar modelos",
"noModelsInstalled": "Parece que no tienes ningún modelo instalado",
"gettingStartedSeries": "¿Desea más orientación? Consulte nuestra <LinkComponent>Serie de introducción</LinkComponent> para obtener consejos sobre cómo aprovechar todo el potencial de Invoke Studio.",
"toGetStartedLocal": "Para empezar, asegúrate de descargar o importar los modelos necesarios para ejecutar Invoke. A continuación, introduzca un mensaje en el cuadro y haga clic en <StrongComponent>Invocar</StrongComponent> para generar su primera imagen. Seleccione una plantilla para mejorar los resultados. Puede elegir guardar sus imágenes directamente en <StrongComponent>Galería</StrongComponent> o editarlas en el <StrongComponent>Lienzo</StrongComponent>."

View File

@@ -114,6 +114,7 @@
"sortDirection": "Direction de tri",
"sideBySide": "Côte-à-Côte",
"hover": "Au passage de la souris",
"assets": "Ressources",
"alwaysShowImageSizeBadge": "Toujours montrer le badge de taille de l'Image",
"gallery": "Galerie",
"bulkDownloadRequestFailed": "Problème lors de la préparation du téléchargement",
@@ -351,6 +352,7 @@
"noT5EncoderModelSelected": "Aucun modèle T5 Encoder sélectionné pour la génération FLUX",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), la largeur de la bounding box mise à l'échelle est {{width}}",
"canvasIsCompositing": "La toile est en train de composer",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} collection vide",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}} : trop peu d'éléments, minimum {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}} : trop d'éléments, maximum {{maxItems}}",
"canvasIsSelectingObject": "La toile est occupée (sélection d'objet)"
@@ -2169,6 +2171,8 @@
"toGetStarted": "Pour commencer, saisissez un prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement dans la <StrongComponent>Galerie</StrongComponent> ou de les modifier sur la <StrongComponent>Toile</StrongComponent>.",
"gettingStartedSeries": "Vous souhaitez plus de conseils? Consultez notre <LinkComponent>Série de démarrage</LinkComponent> pour des astuces sur l'exploitation du plein potentiel de l'Invoke Studio.",
"noModelsInstalled": "Il semble qu'aucun modèle ne soit installé",
"downloadStarterModels": "Télécharger les modèles de démarrage",
"importModels": "Importer des Modèles",
"toGetStartedLocal": "Pour commencer, assurez-vous de télécharger ou d'importer des modèles nécessaires pour exécuter Invoke. Ensuite, saisissez le prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement sur <StrongComponent>Galerie</StrongComponent> ou les modifier sur la <StrongComponent>Toile</StrongComponent>."
},
"upsell": {
@@ -2226,10 +2230,6 @@
"understandingImageToImageAndDenoising": {
"title": "Comprendre l'Image-à-Image et le Débruitage",
"description": "Aperçu des transformations d'image à image et du débruitage dans Invoke."
},
"howDoIOutpaint": {
"title": "Comment effectuer un outpainting?",
"description": "Guide pour l'extension au-delà des bordures de l'image originale."
}
},
"gettingStarted": "Commencer",

View File

@@ -97,15 +97,7 @@
"ok": "Ok",
"generating": "Generazione",
"loadingModel": "Caricamento del modello",
"warnings": "Avvisi",
"step": "Passo",
"values": "Valori",
"start": "Inizio",
"end": "Fine",
"resetToDefaults": "Ripristina le impostazioni predefinite",
"seed": "Seme",
"combinatorial": "Combinatorio",
"count": "Quantità"
"warnings": "Avvisi"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -116,6 +108,7 @@
"deleteImage_many": "Elimina {{count}} immagini",
"deleteImage_other": "Elimina {{count}} immagini",
"deleteImagePermanent": "Le immagini eliminate non possono essere ripristinate.",
"assets": "Risorse",
"autoAssignBoardOnClick": "Assegna automaticamente la bacheca al clic",
"featuresWillReset": "Se elimini questa immagine, quelle funzionalità verranno immediatamente ripristinate.",
"loading": "Caricamento in corso",
@@ -675,7 +668,7 @@
"addingImagesTo": "Aggiungi immagini a",
"systemDisconnected": "Sistema disconnesso",
"missingNodeTemplate": "Modello di nodo mancante",
"missingInputForField": "ingresso mancante",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: ingresso mancante",
"missingFieldTemplate": "Modello di campo mancante",
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), altezza riquadro è {{height}}",
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), larghezza riquadro è {{width}}",
@@ -688,22 +681,11 @@
"canvasIsRasterizing": "La tela è occupata (sta rasterizzando)",
"canvasIsCompositing": "La tela è occupata (in composizione)",
"canvasIsFiltering": "La tela è occupata (sta filtrando)",
"collectionTooManyItems": "troppi elementi, massimo {{maxItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi elementi, massimo {{maxItems}}",
"canvasIsSelectingObject": "La tela è occupata (selezione dell'oggetto)",
"collectionTooFewItems": "troppi pochi elementi, minimo {{minItems}}",
"fluxModelMultipleControlLoRAs": "È possibile utilizzare solo 1 Controllo LoRA alla volta",
"collectionNumberGTMax": "{{value}} > {{maximum}} (incr max)",
"collectionStringTooLong": "troppo lungo, massimo {{maxLength}}",
"batchNodeNotConnected": "Nodo Lotto non connesso: {{label}}",
"batchNodeEmptyCollection": "Alcuni nodi lotto hanno raccolte vuote",
"emptyBatches": "lotti vuoti",
"batchNodeCollectionSizeMismatch": "Le dimensioni della raccolta nel Lotto {{batchGroupId}} non corrispondono",
"invalidBatchConfigurationCannotCalculate": "Configurazione lotto non valida; impossibile calcolare",
"collectionStringTooShort": "troppo corto, minimo {{minLength}}",
"collectionNumberNotMultipleOf": "{{value}} non è multiplo di {{multipleOf}}",
"collectionNumberLTMin": "{{value}} < {{minimum}} (incr min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (excl max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)"
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi pochi elementi, minimo {{minItems}}",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} raccolta vuota",
"fluxModelMultipleControlLoRAs": "È possibile utilizzare solo 1 Controllo LoRA alla volta"
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@@ -831,8 +813,7 @@
"imagesWillBeAddedTo": "Le immagini caricate verranno aggiunte alle risorse della bacheca {{boardName}}.",
"uploadFailedInvalidUploadDesc_withCount_one": "Devi caricare al massimo 1 immagine PNG o JPEG.",
"uploadFailedInvalidUploadDesc_withCount_many": "Devi caricare al massimo {{count}} immagini PNG o JPEG.",
"uploadFailedInvalidUploadDesc_withCount_other": "Devi caricare al massimo {{count}} immagini PNG o JPEG.",
"outOfMemoryErrorDescLocal": "Segui la nostra <LinkComponent>guida per bassa VRAM</LinkComponent> per ridurre gli OOM."
"uploadFailedInvalidUploadDesc_withCount_other": "Devi caricare al massimo {{count}} immagini PNG o JPEG."
},
"accessibility": {
"invokeProgressBar": "Barra di avanzamento generazione",
@@ -991,25 +972,7 @@
"noWorkflows": "Nessun flusso di lavoro",
"workflowHelpText": "Hai bisogno di aiuto? Consulta la nostra guida <LinkComponent>Introduzione ai flussi di lavoro</LinkComponent>.",
"specialDesc": "Questa invocazione comporta una gestione speciale nell'applicazione. Ad esempio, i nodi Lotto vengono utilizzati per mettere in coda più grafici da un singolo flusso di lavoro.",
"internalDesc": "Questa invocazione è utilizzata internamente da Invoke. Potrebbe subire modifiche significative durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento.",
"addItem": "Aggiungi elemento",
"generateValues": "Genera valori",
"generatorNoValues": "vuoto",
"linearDistribution": "Distribuzione lineare",
"parseString": "Analizza stringa",
"splitOn": "Diviso su",
"noBatchGroup": "nessun gruppo",
"generatorLoading": "caricamento",
"generatorLoadFromFile": "Carica da file",
"dynamicPromptsRandom": "Prompt dinamici (casuali)",
"dynamicPromptsCombinatorial": "Prompt dinamici (combinatori)",
"floatRangeGenerator": "Generatore di intervalli di numeri in virgola mobile",
"integerRangeGenerator": "Generatore di intervalli di numeri interi",
"uniformRandomDistribution": "Distribuzione casuale uniforme",
"generatorNRandomValues_one": "{{count}} valore casuale",
"generatorNRandomValues_many": "{{count}} valori casuali",
"generatorNRandomValues_other": "{{count}} valori casuali",
"arithmeticSequence": "Sequenza aritmetica"
"internalDesc": "Questa invocazione è utilizzata internamente da Invoke. Potrebbe subire modifiche significative durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento."
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1131,11 +1094,7 @@
"generation": "Generazione",
"other": "Altro",
"gallery": "Galleria",
"batchSize": "Dimensione del lotto",
"cancelAllExceptCurrentQueueItemAlertDialog2": "Vuoi davvero annullare tutti gli elementi in coda in sospeso?",
"confirm": "Conferma",
"cancelAllExceptCurrentQueueItemAlertDialog": "L'annullamento di tutti gli elementi della coda, eccetto quello corrente, interromperà gli elementi in sospeso ma consentirà il completamento di quello in corso.",
"cancelAllExceptCurrentTooltip": "Annulla tutto tranne l'elemento corrente"
"batchSize": "Dimensione del lotto"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -1179,8 +1138,7 @@
"dynamicPrompts": "Prompt dinamici",
"promptsPreview": "Anteprima dei prompt",
"showDynamicPrompts": "Mostra prompt dinamici",
"loading": "Generazione prompt dinamici...",
"promptsToGenerate": "Prompt da generare"
"loading": "Generazione prompt dinamici..."
},
"popovers": {
"paramScheduler": {
@@ -1949,43 +1907,7 @@
},
"forMoreControl": "Per un maggiore controllo, fare clic su Avanzate qui sotto.",
"advanced": "Avanzate",
"processingLayerWith": "Elaborazione del livello con il filtro {{type}}.",
"img_blur": {
"label": "Sfoca immagine",
"description": "Sfoca il livello selezionato.",
"blur_type": "Tipo di sfocatura",
"blur_radius": "Raggio",
"gaussian_type": "Gaussiana"
},
"img_noise": {
"size": "Dimensione del rumore",
"salt_and_pepper_type": "Sale e pepe",
"gaussian_type": "Gaussiano",
"noise_color": "Rumore colorato",
"description": "Aggiunge rumore al livello selezionato.",
"noise_type": "Tipo di rumore",
"label": "Aggiungi rumore",
"noise_amount": "Quantità"
},
"adjust_image": {
"description": "Regola il canale selezionato di un'immagine.",
"alpha": "Alfa (RGBA)",
"label": "Regola l'immagine",
"blue": "Blu (RGBA)",
"luminosity": "Luminosità (LAB)",
"channel": "Canale",
"value_setting": "Valore",
"scale_values": "Scala i valori",
"red": "Rosso (RGBA)",
"green": "Verde (RGBA)",
"cyan": "Ciano (CMYK)",
"magenta": "Magenta (CMYK)",
"yellow": "Giallo (CMYK)",
"black": "Nero (CMYK)",
"hue": "Tonalità (HSV)",
"saturation": "Saturazione (HSV)",
"value": "Valore (HSV)"
}
"processingLayerWith": "Elaborazione del livello con il filtro {{type}}."
},
"controlLayers_withCount_hidden": "Livelli di controllo ({{count}} nascosti)",
"regionalGuidance_withCount_hidden": "Guida regionale ({{count}} nascosti)",
@@ -2244,9 +2166,10 @@
"newUserExperience": {
"gettingStartedSeries": "Desideri maggiori informazioni? Consulta la nostra <LinkComponent>Getting Started Series</LinkComponent> per suggerimenti su come sfruttare appieno il potenziale di Invoke Studio.",
"toGetStarted": "Per iniziare, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>.",
"noModelsInstalled": "Sembra che non hai installato alcun modello! Puoi <DownloadStarterModelsButton>scaricare un pacchetto di modelli di avvio</DownloadStarterModelsButton> o <ImportModelsButton>importare modelli</ImportModelsButton>.",
"toGetStartedLocal": "Per iniziare, assicurati di scaricare o importare i modelli necessari per eseguire Invoke. Quindi, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>.",
"lowVRAMMode": "Per prestazioni ottimali, segui la nostra <LinkComponent>guida per bassa VRAM</LinkComponent>."
"importModels": "Importa modelli",
"downloadStarterModels": "Scarica i modelli per iniziare",
"noModelsInstalled": "Sembra che tu non abbia installato alcun modello",
"toGetStartedLocal": "Per iniziare, assicurati di scaricare o importare i modelli necessari per eseguire Invoke. Quindi, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>."
},
"whatsNew": {
"whatsNewInInvoke": "Novità in Invoke",
@@ -2254,11 +2177,7 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Modalità Bassa-VRAM",
"Gestione dinamica della memoria",
"Tempi di caricamento del modello più rapidi",
"Meno errori di memoria",
"Funzionalità lotto del flusso di lavoro ampliate"
"<StrongComponent>Livelli di controllo Flux</StrongComponent>: nuovi modelli di controllo per il rilevamento dei bordi e la mappatura della profondità sono ora supportati per i modelli di Flux dev."
]
},
"system": {
@@ -2348,10 +2267,5 @@
"watch": "Guarda",
"studioSessionsDesc1": "Dai un'occhiata a <StudioSessionsPlaylistLink /> per approfondimenti su Invoke.",
"studioSessionsDesc2": "Unisciti al nostro <DiscordLink /> per partecipare alle sessioni live e fare domande. Le sessioni vengono caricate sulla playlist la settimana successiva."
},
"modelCache": {
"clear": "Cancella la cache del modello",
"clearSucceeded": "Cache del modello cancellata",
"clearFailed": "Problema durante la cancellazione della cache del modello"
}
}

View File

@@ -106,6 +106,7 @@
"featuresWillReset": "この画像を削除すると、これらの機能は即座にリセットされます。",
"unstarImage": "スターを外す",
"loading": "ロード中",
"assets": "アセット",
"currentlyInUse": "この画像は現在下記の機能を使用しています:",
"drop": "ドロップ",
"dropOrUpload": "$t(gallery.drop) またはアップロード",

View File

@@ -68,6 +68,7 @@
"gallerySettings": "갤러리 설정",
"deleteSelection": "선택 항목 삭제",
"featuresWillReset": "이 이미지를 삭제하면 해당 기능이 즉시 재설정됩니다.",
"assets": "자산",
"noImagesInGallery": "보여줄 이미지가 없음",
"autoSwitchNewImages": "새로운 이미지로 자동 전환",
"loading": "불러오는 중",

View File

@@ -90,6 +90,7 @@
"deleteImage_one": "Verwijder afbeelding",
"deleteImage_other": "",
"deleteImagePermanent": "Verwijderde afbeeldingen kunnen niet worden hersteld.",
"assets": "Eigen onderdelen",
"autoAssignBoardOnClick": "Ken automatisch bord toe bij klikken",
"featuresWillReset": "Als je deze afbeelding verwijdert, dan worden deze functies onmiddellijk teruggezet.",
"loading": "Bezig met laden",

View File

@@ -105,6 +105,7 @@
"assetsTab": "Pliki, które wrzuciłeś do użytku w twoich projektach.",
"currentlyInUse": "Ten obraz jest obecnie w użyciu przez następujące funkcje:",
"boardsSettings": "Ustawienia tablic",
"assets": "Aktywy",
"autoAssignBoardOnClick": "Automatycznie przypisz tablicę po kliknięciu",
"copy": "Kopiuj"
},

View File

@@ -106,6 +106,7 @@
"deleteImage_one": "Удалить изображение",
"deleteImage_few": "Удалить {{count}} изображения",
"deleteImage_many": "Удалить {{count}} изображений",
"assets": "Ресурсы",
"autoAssignBoardOnClick": "Авто-назначение доски по клику",
"deleteSelection": "Удалить выделенное",
"featuresWillReset": "Если вы удалите это изображение, эти функции будут немедленно сброшены.",

View File

@@ -195,6 +195,7 @@
},
"gallery": {
"deleteImagePermanent": "Silinen görseller geri getirilemez.",
"assets": "Özkaynaklar",
"autoAssignBoardOnClick": "Tıklanan Panoya Otomatik Atama",
"loading": "Yükleniyor",
"starImage": "Yıldız Koy",

View File

@@ -86,6 +86,7 @@
"bulkDownloadRequestedDesc": "Yêu cầu tải xuống đang được chuẩn bị. Vui lòng chờ trong giây lát.",
"starImage": "Gắn Sao Cho Ảnh",
"openViewer": "Mở Trình Xem",
"assets": "Tài Nguyên",
"viewerImage": "Trình Xem Ảnh",
"sideBySide": "Cạnh Nhau",
"alwaysShowImageSizeBadge": "Luôn Hiển Thị Kích Thước Ảnh",
@@ -219,17 +220,7 @@
"tab": "Tab",
"loadingModel": "Đang Tải Model",
"generating": "Đang Tạo Sinh",
"warnings": "Cảnh Báo",
"count": "Đếm",
"step": "Bước",
"values": "Giá Trị",
"start": "Bắt Đầu",
"end": "Kết Thúc",
"min": "Tối Thiểu",
"max": "Tối Đa",
"resetToDefaults": "Đặt Lại Về Mặc Định",
"seed": "Hạt Giống",
"combinatorial": "Tổ Hợp"
"warnings": "Cảnh Báo"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -974,23 +965,7 @@
"outputFieldTypeParseError": "Không thể phân tích loại dữ liệu đầu ra của {{node}}.{{field}} ({{message}})",
"modelAccessError": "Không thể tìm thấy model {{key}}, chuyển về mặc định",
"internalDesc": "Trình kích hoạt này được dùng bên trong bởi Invoke. Nó có thể phá hỏng thay đổi trong khi cập nhật ứng dụng và có thể bị xoá bất cứ lúc nào.",
"specialDesc": "Trình kích hoạt này có một số xử lý đặc biệt trong ứng dụng. Ví dụ, Node Hàng Loạt được dùng để xếp vào nhiều đồ thị từ một workflow.",
"addItem": "Thêm Mục",
"generateValues": "Cho Ra Giá Trị",
"floatRangeGenerator": "Phạm Vị Tạo Ra Số Thực",
"integerRangeGenerator": "Phạm Vị Tạo Ra Số Nguyên",
"linearDistribution": "Phân Bố Tuyến Tính",
"uniformRandomDistribution": "Phân Bố Ngẫu Nhiên Đồng Nhất",
"parseString": "Phân Tích Chuỗi",
"noBatchGroup": "không có nhóm",
"generatorNoValues": "trống",
"splitOn": "Tách Ở",
"arithmeticSequence": "Cấp Số Cộng",
"generatorNRandomValues_other": "{{count}} giá trị ngẫu nhiên",
"generatorLoading": "đang tải",
"generatorLoadFromFile": "Tải Từ Tệp",
"dynamicPromptsRandom": "Dynamic Prompts (Ngẫu Nhiên)",
"dynamicPromptsCombinatorial": "Dynamic Prompts (Tổ Hợp)"
"specialDesc": "Trình kích hoạt này có một số xử lý đặc biệt trong ứng dụng. Ví dụ, Node Hàng Loạt được dùng để xếp vào nhiều đồ thị từ một workflow."
},
"popovers": {
"paramCFGRescaleMultiplier": {
@@ -1458,24 +1433,13 @@
"missingNodeTemplate": "Thiếu mẫu trình bày node",
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), chiều dài hộp giới hạn là {{height}}",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), tỉ lệ chiều rộng hộp giới hạn là {{width}}",
"missingInputForField": "thiếu đầu vào",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: thiếu đầu vào",
"missingFieldTemplate": "Thiếu vùng mẫu trình bày",
"collectionTooFewItems": "quá ít mục, ti thiểu là {{minItems}}",
"collectionTooManyItems": "quá nhiều mục, tối đa là {{maxItems}}",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} tài nguyên trống",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: quá ít mục, tối thiểu {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: quá nhiều mục, tối đa {{maxItems}}",
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)",
"fluxModelMultipleControlLoRAs": "Chỉ có thể dùng 1 LoRA Điều Khiển Được",
"collectionStringTooLong": "quá dài, tối đa là {{maxLength}}",
"collectionStringTooShort": "quá ngắn, tối thiểu là {{minLength}}",
"collectionNumberGTMax": "{{value}} > {{maximum}} (giá trị tối đa)",
"collectionNumberLTMin": "{{value}} < {{minimum}} (giá trị tối thiểu)",
"collectionNumberNotMultipleOf": "{{value}} không phải bội của {{multipleOf}}",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (giá trị chọn lọc tối thiểu)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (giá trị chọn lọc tối đa)",
"batchNodeCollectionSizeMismatch": "Kích cỡ tài nguyên không phù hợp với Lô {{batchGroupId}}",
"emptyBatches": "lô trống",
"batchNodeNotConnected": "Node Hàng Loạt chưa được kết nối: {{label}}",
"batchNodeEmptyCollection": "Một vài node hàng loạt có tài nguyên rỗng",
"invalidBatchConfigurationCannotCalculate": "Thiết lập lô không hợp lệ; không thể tính toán"
"fluxModelMultipleControlLoRAs": "Chỉ có thể dùng 1 LoRA Điều Khiển Được"
},
"cfgScale": "Thang CFG",
"useSeed": "Dùng Hạt Giống",
@@ -1494,8 +1458,8 @@
"recallMetadata": "Gợi Lại Metadata",
"clipSkip": "CLIP Skip",
"general": "Cài Đặt Chung",
"boxBlur": "Làm Mờ Dạng Box",
"gaussianBlur": "Làm Mờ Dạng Gaussian",
"boxBlur": "Box Blur",
"gaussianBlur": "Gaussian Blur",
"staged": "Staged (Tăng khử nhiễu có hệ thống)",
"scaledHeight": "Tỉ Lệ Dài",
"cancel": {
@@ -1548,12 +1512,11 @@
"perPromptLabel": "Một Hạt Giống Mỗi Ảnh",
"perIterationLabel": "Hạt Giống Mỗi Lần Lặp Lại"
},
"loading": "Tạo Sinh Bằng Dynamic Prompt...",
"loading": "Tạo Sinh ng Dynamic Prompt...",
"showDynamicPrompts": "HIện Dynamic Prompt",
"maxPrompts": "Số Lệnh Tối Đa",
"promptsPreview": "Xem Trước Lệnh",
"dynamicPrompts": "Dynamic Prompt",
"promptsToGenerate": "Lệnh Để Tạo Sinh"
"dynamicPrompts": "Dynamic Prompt"
},
"settings": {
"beta": "Beta",
@@ -1896,25 +1859,7 @@
},
"advanced": "Nâng Cao",
"processingLayerWith": "Đang xử lý layer với bộ lọc {{type}}.",
"forMoreControl": "Để kiểm soát tốt hơn, bấm vào mục Nâng Cao bên dưới.",
"img_blur": {
"description": "Làm mờ layer được chọn.",
"blur_type": "Dạng Làm Mờ",
"blur_radius": "Radius",
"gaussian_type": "Gaussian",
"label": "Làm Mờ Ảnh",
"box_type": "Box"
},
"img_noise": {
"salt_and_pepper_type": "Salt and Pepper",
"noise_amount": "Lượng Nhiễu",
"label": "Độ Nhiễu Ảnh",
"description": "Tăng độ nhiễu vào layer được chọn.",
"noise_type": "Dạng Nhiễu",
"gaussian_type": "Gaussian",
"noise_color": "Màu Nhiễu",
"size": "Cỡ Nhiễu"
}
"forMoreControl": "Để kiểm soát tốt hơn, bấm vào mục Nâng Cao bên dưới."
},
"transform": {
"fitModeCover": "Che Phủ",
@@ -2122,8 +2067,7 @@
"problemCopyingImage": "Không Thể Sao Chép Ảnh",
"problemDownloadingImage": "Không Thể Tải Xuống Ảnh",
"problemCopyingLayer": "Không Thể Sao Chép Layer",
"problemSavingLayer": "Không Thể Lưu Layer",
"outOfMemoryErrorDescLocal": "Làm theo <LinkComponent>hướng dẫn VRAM Thấp</LinkComponent> của chúng tôi để hạn chế OOM (Tràn bộ nhớ)."
"problemSavingLayer": "Không Thể Lưu Layer"
},
"ui": {
"tabs": {
@@ -2209,8 +2153,9 @@
"toGetStartedLocal": "Để bắt đầu, hãy chắc chắn đã tải xuống hoặc thêm vào model cần để chạy Invoke. Sau đó, nhập lệnh vào hộp và nhấp chuột vào <StrongComponent>Kích Hoạt</StrongComponent> để tạo ra bức ảnh đầu tiên. Chọn một mẫu trình bày cho lệnh để cải thiện kết quả. Bạn có thể chọn để lưu ảnh trực tiếp vào <StrongComponent>Thư Viện</StrongComponent> hoặc chỉnh sửa chúng ở <StrongComponent>Canvas</StrongComponent>.",
"gettingStartedSeries": "Cần thêm hướng dẫn? Xem thử <LinkComponent>Bắt Đầu Làm Quen</LinkComponent> để biết thêm mẹo khai thác toàn bộ tiềm năng của Invoke Studio.",
"toGetStarted": "Để bắt đầu, hãy nhập lệnh vào hộp và nhấp chuột vào <StrongComponent>Kích Hoạt</StrongComponent> để tạo ra bức ảnh đầu tiên. Chọn một mẫu trình bày cho lệnh để cải thiện kết quả. Bạn có thể chọn để lưu ảnh trực tiếp vào <StrongComponent>Thư Viện</StrongComponent> hoặc chỉnh sửa chúng ở <StrongComponent>Canvas</StrongComponent>.",
"noModelsInstalled": "Dường như bạn chưa tải model nào cả! Bạn có thể <DownloadStarterModelsButton>tải xuống các model khởi đầu</DownloadStarterModelsButton> hoặc <ImportModelsButton>nhập vào thêm model</ImportModelsButton>.",
"lowVRAMMode": "Cho hiệu suất tốt nhất, hãy làm theo <LinkComponent>hướng dẫn VRAM Thấp</LinkComponent> của chúng tôi."
"downloadStarterModels": "Tải Xuống Model Khởi Đầu",
"importModels": "Nhập Vào Model",
"noModelsInstalled": "Hình như bạn không có model nào được tải cả"
},
"whatsNew": {
"whatsNewInInvoke": "Có Gì Mới Ở Invoke",
@@ -2218,11 +2163,7 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"Chế độ VRAM thấp",
"Trình quản lý bộ nhớ động",
"Tải model nhanh hơn",
"Ít lỗi bộ nhớ hơn",
"Mở rộng khả năng xử lý hàng loạt workflow"
"<StrongComponent>Hướng Dẫn Khu Vực FLUX (beta)</StrongComponent>: Bản beta của Hướng Dẫn Khu Vực FLUX của chúng ta đã có mắt tại bảng điều khiển lệnh khu vực."
]
},
"upsell": {
@@ -2292,10 +2233,5 @@
},
"controlCanvas": "Điều Khiển Canvas",
"watch": "Xem"
},
"modelCache": {
"clearSucceeded": "Cache Model Đã Được Dọn",
"clearFailed": "Có Vấn Đề Khi Dọn Cache Model",
"clear": "Dọn Cache Model"
}
}

View File

@@ -107,6 +107,7 @@
"noImagesInGallery": "无图像可用于显示",
"deleteImage_other": "删除{{count}}张图片",
"deleteImagePermanent": "删除的图片无法被恢复。",
"assets": "素材",
"autoAssignBoardOnClick": "点击后自动分配面板",
"featuresWillReset": "如果您删除该图像,这些功能会立即被重置。",
"loading": "加载中",

View File

@@ -12,12 +12,10 @@ import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
import {
NewCanvasSessionDialog,
NewGallerySessionDialog,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
@@ -25,7 +23,6 @@ import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModa
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
import { CancelAllExceptCurrentQueueItemConfirmationAlertDialog } from 'features/queue/components/CancelAllExceptCurrentQueueItemConfirmationAlertDialog';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
@@ -100,7 +97,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<ChangeBoardModal />
<DynamicPromptsModal />
<StylePresetModal />
<CancelAllExceptCurrentQueueItemConfirmationAlertDialog />
<ClearQueueConfirmationsAlertDialog />
<NewWorkflowConfirmationAlertDialog />
<DeleteStylePresetDialog />
@@ -114,9 +110,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<ImageContextMenu />
<FullscreenDropzone />
<VideosModal />
<CanvasManagerProviderGate>
<CanvasPasteModal />
</CanvasManagerProviderGate>
</ErrorBoundary>
);
};

View File

@@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate } from 'lodash-es';
import { truncate, upperFirst } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { JsonObject } from 'type-fest';
@@ -52,12 +52,15 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
const result = zPydanticValidationError.safeParse(response);
if (result.success) {
result.data.data.detail.map((e) => {
const description = truncate(e.msg.replace(/^(Value|Index|Key) error, /i, ''), { length: 256 });
toast({
id: 'QUEUE_BATCH_FAILED',
title: t('queue.batchFailedToQueue'),
title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error',
description,
description: truncate(
`Path:
${e.loc.join('.')}`,
{ length: 128 }
),
});
});
} else if (response.status !== 403) {

View File

@@ -1,14 +1,16 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { resolveBatchValue } from 'features/queue/store/readiness';
import { groupBy } from 'lodash-es';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Batch, BatchConfig } from 'services/api/types';
const log = logger('workflows');
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
@@ -31,54 +33,28 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
const data: Batch['data'] = [];
const invocationNodes = nodes.nodes.filter(isInvocationNode);
const batchNodes = invocationNodes.filter(isBatchNode);
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
// Then, we will create a batch data collection item for each group
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
for (const node of batchNodes) {
const value = resolveBatchValue(node, invocationNodes, nodes.edges);
const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value';
const edgesFromBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === sourceHandle);
if (batchGroupId !== 'None') {
// If this batch node has a batch_group_id, we will zip the data collection items
for (const edge of edgesFromBatch) {
if (!edge.targetHandle) {
break;
}
zippedBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: value,
});
}
} else {
// Otherwise add the data collection items to root of the batch so they are not zipped
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromBatch) {
if (!edge.targetHandle) {
break;
}
productBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: value,
});
}
if (productBatchDataCollectionItems.length > 0) {
data.push(productBatchDataCollectionItems);
}
}
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
// Finally, if this batch data collection item has any items, add it to the data array
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
data.push(zippedBatchDataCollectionItems);
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromImageBatch) {
if (!edge.targetHandle) {
break;
}
batchDataCollectionItem.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: images.value,
});
}
if (batchDataCollectionItem.length > 0) {
data.push(batchDataCollectionItem);
}
}
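
For readers comparing the two sides of this hunk: `Batch['data']` is a list of groups of batch data items, where the items inside one group are zipped together element by element and separate groups effectively multiply with each other. The removed branch builds one zipped group per `batch_group_id`, while the added branch emits one group per image batch node. The sketch below only illustrates that shape, using a simplified local type as a stand-in for the generated `services/api/types` definitions.

```ts
// Simplified stand-in for the generated Batch['data'] type (illustration only).
type BatchDatum = { node_path: string; field_name: string; items: string[] };
type BatchData = BatchDatum[][]; // inner arrays are zipped, outer entries multiply

// Two batch nodes sharing a batch_group_id end up in one inner array, so their
// items pair up index by index: (image_1, mask_1), (image_2, mask_2).
const zippedGroup: BatchDatum[] = [
  { node_path: 'resize', field_name: 'image', items: ['image_1', 'image_2'] },
  { node_path: 'apply_mask', field_name: 'mask', items: ['mask_1', 'mask_2'] },
];

// A batch node with no group id gets its own outer entry, so it combines with
// the zipped group rather than pairing with it: 2 zipped pairs x 2 prompts.
const ungrouped: BatchDatum[] = [
  { node_path: 'prompt', field_name: 'value', items: ['a cat', 'a dog'] },
];

const data: BatchData = [zippedGroup, ungrouped];
console.log(data.flat().length); // 3 batch datum entries across 2 groups
```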

View File

@@ -23,7 +23,6 @@ export type AppFeature =
| 'pauseQueue'
| 'resumeQueue'
| 'invocationCache'
| 'modelCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken';

View File

@@ -2,7 +2,6 @@ import { Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-l
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasContextMenuItemsCropCanvasToBbox } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItemsCropCanvasToBbox';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { useCopyCanvasToClipboard } from 'features/controlLayers/hooks/copyHooks';
import {
useNewControlLayerFromBbox,
useNewGlobalReferenceImageFromBbox,
@@ -14,13 +13,12 @@ import {
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold, PiFloppyDiskBold } from 'react-icons/pi';
import { PiFloppyDiskBold } from 'react-icons/pi';
export const CanvasContextMenuGlobalMenuItems = memo(() => {
const { t } = useTranslation();
const saveSubMenu = useSubMenu();
const newSubMenu = useSubMenu();
const copySubMenu = useSubMenu();
const isBusy = useCanvasIsBusy();
const saveCanvasToGallery = useSaveCanvasToGallery();
const saveBboxToGallery = useSaveBboxToGallery();
@@ -28,8 +26,6 @@ export const CanvasContextMenuGlobalMenuItems = memo(() => {
const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox();
const newRasterLayerFromBbox = useNewRasterLayerFromBbox();
const newControlLayerFromBbox = useNewControlLayerFromBbox();
const copyCanvasToClipboard = useCopyCanvasToClipboard('canvas');
const copyBboxToClipboard = useCopyCanvasToClipboard('bbox');
return (
<>
@@ -71,21 +67,6 @@ export const CanvasContextMenuGlobalMenuItems = memo(() => {
</MenuList>
</Menu>
</MenuItem>
<MenuItem {...copySubMenu.parentMenuItemProps} icon={<PiCopyBold />}>
<Menu {...copySubMenu.menuProps}>
<MenuButton {...copySubMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.canvasContextMenu.copyToClipboard')} />
</MenuButton>
<MenuList {...copySubMenu.menuListProps}>
<MenuItem icon={<PiCopyBold />} isDisabled={isBusy} onClick={copyCanvasToClipboard}>
{t('controlLayers.canvasContextMenu.copyCanvasToClipboard')}
</MenuItem>
<MenuItem icon={<PiCopyBold />} isDisabled={isBusy} onClick={copyBboxToClipboard}>
{t('controlLayers.canvasContextMenu.copyBboxToClipboard')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
</MenuGroup>
</>
);

View File

@@ -1,150 +0,0 @@
import {
Button,
Flex,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { createNewCanvasEntityFromImage } from 'features/imageActions/actions';
import { toast } from 'features/toast/toast';
import { atom } from 'nanostores';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiImageBold } from 'react-icons/pi';
import { useUploadImageMutation } from 'services/api/endpoints/images';
const $imageFile = atom<File | null>(null);
export const setFileToPaste = (file: File) => $imageFile.set(file);
const clearFileToPaste = () => $imageFile.set(null);
export const CanvasPasteModal = memo(() => {
useAssertSingleton('CanvasPasteModal');
const { dispatch, getState } = useAppStore();
const { t } = useTranslation();
const imageToPaste = useStore($imageFile);
const canvasManager = useCanvasManager();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [uploadImage, { isLoading }] = useUploadImageMutation({ fixedCacheKey: 'canvasPasteModal' });
const getPosition = useCallback(
(destination: 'canvas' | 'bbox') => {
const { x, y } = canvasManager.stateApi.getBbox().rect;
if (destination === 'bbox') {
return { x, y };
}
const rasterLayerAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
if (rasterLayerAdapters.length === 0) {
return { x, y };
}
{
const { x, y } = canvasManager.compositor.getRectOfAdapters(rasterLayerAdapters);
return { x, y };
}
},
[canvasManager.compositor, canvasManager.stateApi]
);
const handlePaste = useCallback(
async (file: File, destination: 'assets' | 'canvas' | 'bbox') => {
try {
const is_intermediate = destination !== 'assets';
const imageDTO = await uploadImage({
file,
is_intermediate,
image_category: 'user',
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}).unwrap();
if (destination !== 'assets') {
createNewCanvasEntityFromImage({
type: 'raster_layer',
imageDTO,
dispatch,
getState,
overrides: { position: getPosition(destination) },
});
}
} catch {
toast({
title: t('toast.pasteFailed'),
status: 'error',
});
} finally {
clearFileToPaste();
toast({
title: t('toast.pasteSuccess', {
destination:
destination === 'assets'
? t('controlLayers.pasteToAssets')
: destination === 'bbox'
? t('controlLayers.pasteToBbox')
: t('controlLayers.pasteToCanvas'),
}),
status: 'success',
});
}
},
[autoAddBoardId, dispatch, getPosition, getState, t, uploadImage]
);
const pasteToAssets = useCallback(() => {
if (!imageToPaste) {
return;
}
handlePaste(imageToPaste, 'assets');
}, [handlePaste, imageToPaste]);
const pasteToCanvas = useCallback(() => {
if (!imageToPaste) {
return;
}
handlePaste(imageToPaste, 'canvas');
}, [handlePaste, imageToPaste]);
const pasteToBbox = useCallback(() => {
if (!imageToPaste) {
return;
}
handlePaste(imageToPaste, 'bbox');
}, [handlePaste, imageToPaste]);
return (
<Modal isOpen={imageToPaste !== null} onClose={clearFileToPaste} useInert={false} isCentered size="2xl">
<ModalOverlay />
<ModalContent>
<ModalHeader>{t('controlLayers.pasteTo')}</ModalHeader>
<ModalCloseButton />
<ModalBody display="flex" justifyContent="center">
<Flex flexDir="column" gap={4} w="min-content">
<Button size="lg" onClick={pasteToCanvas} isDisabled={isLoading} leftIcon={<PiImageBold />}>
{t('controlLayers.pasteToCanvasDesc')}
</Button>
<Button size="lg" onClick={pasteToBbox} isDisabled={isLoading} leftIcon={<PiBoundingBoxBold />}>
{t('controlLayers.pasteToBboxDesc')}
</Button>
<Button size="lg" onClick={pasteToAssets} isDisabled={isLoading} variant="ghost">
{t('controlLayers.pasteToAssetsDesc')}
</Button>
</Flex>
</ModalBody>
<ModalFooter>
<Button onClick={clearFileToPaste} variant="ghost" isLoading={isLoading}>
{t('common.cancel')}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
});
CanvasPasteModal.displayName = 'CanvasPasteModal';

View File

@@ -1,99 +0,0 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { AdjustImageFilterConfig, AjustImageChannels } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS, isAjustImageChannels } from 'features/controlLayers/store/filters';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<AdjustImageFilterConfig>;
const DEFAULTS = IMAGE_FILTERS.adjust_image.buildDefaults();
export const FilterAdjustImage = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleChannelChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isAjustImageChannels(v?.value)) {
return;
}
onChange({ ...config, channel: v.value });
},
[config, onChange]
);
const handleValueChange = useCallback(
(v: number) => {
onChange({ ...config, value: v });
},
[config, onChange]
);
const handleScaleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, scale_values: e.target.checked });
},
[config, onChange]
);
const options: { label: string; value: AjustImageChannels }[] = useMemo(
() => [
{ label: t('controlLayers.filter.adjust_image.red'), value: 'Red (RGBA)' },
{ label: t('controlLayers.filter.adjust_image.green'), value: 'Green (RGBA)' },
{ label: t('controlLayers.filter.adjust_image.blue'), value: 'Blue (RGBA)' },
{ label: t('controlLayers.filter.adjust_image.alpha'), value: 'Alpha (RGBA)' },
{ label: t('controlLayers.filter.adjust_image.cyan'), value: 'Cyan (CMYK)' },
{ label: t('controlLayers.filter.adjust_image.magenta'), value: 'Magenta (CMYK)' },
{ label: t('controlLayers.filter.adjust_image.yellow'), value: 'Yellow (CMYK)' },
{ label: t('controlLayers.filter.adjust_image.black'), value: 'Black (CMYK)' },
{ label: t('controlLayers.filter.adjust_image.hue'), value: 'Hue (HSV)' },
{ label: t('controlLayers.filter.adjust_image.saturation'), value: 'Saturation (HSV)' },
{ label: t('controlLayers.filter.adjust_image.value'), value: 'Value (HSV)' },
{ label: t('controlLayers.filter.adjust_image.luminosity'), value: 'Luminosity (LAB)' },
{ label: t('controlLayers.filter.adjust_image.a'), value: 'A (LAB)' },
{ label: t('controlLayers.filter.adjust_image.b'), value: 'B (LAB)' },
{ label: t('controlLayers.filter.adjust_image.y'), value: 'Y (YCbCr)' },
{ label: t('controlLayers.filter.adjust_image.cb'), value: 'Cb (YCbCr)' },
{ label: t('controlLayers.filter.adjust_image.cr'), value: 'Cr (YCbCr)' },
],
[t]
);
const value = useMemo(() => options.filter((o) => o.value === config.channel)[0], [options, config.channel]);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.adjust_image.channel')}</FormLabel>
<Combobox value={value} options={options} onChange={handleChannelChange} isSearchable={false} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.adjust_image.value_setting')}</FormLabel>
<CompositeSlider
value={config.value}
defaultValue={DEFAULTS.value}
onChange={handleValueChange}
min={0}
max={2}
step={0.0025}
marks
/>
<CompositeNumberInput
value={config.value}
defaultValue={DEFAULTS.value}
onChange={handleValueChange}
min={0}
max={255}
step={0.0025}
/>
</FormControl>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlLayers.filter.adjust_image.scale_values')}</FormLabel>
<Switch defaultChecked={DEFAULTS.scale_values} isChecked={config.scale_values} onChange={handleScaleChange} />
</FormControl>
</>
);
});
FilterAdjustImage.displayName = 'FilterAdjustImage';

View File

@@ -1,72 +0,0 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { BlurFilterConfig, BlurTypes } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS, isBlurTypes } from 'features/controlLayers/store/filters';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<BlurFilterConfig>;
const DEFAULTS = IMAGE_FILTERS.img_blur.buildDefaults();
export const FilterBlur = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleBlurTypeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isBlurTypes(v?.value)) {
return;
}
onChange({ ...config, blur_type: v.value });
},
[config, onChange]
);
const handleRadiusChange = useCallback(
(v: number) => {
onChange({ ...config, radius: v });
},
[config, onChange]
);
const options: { label: string; value: BlurTypes }[] = useMemo(
() => [
{ label: t('controlLayers.filter.img_blur.gaussian_type'), value: 'gaussian' },
{ label: t('controlLayers.filter.img_blur.box_type'), value: 'box' },
],
[t]
);
const value = useMemo(() => options.filter((o) => o.value === config.blur_type)[0], [options, config.blur_type]);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_type')}</FormLabel>
<Combobox value={value} options={options} onChange={handleBlurTypeChange} isSearchable={false} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_radius')}</FormLabel>
<CompositeSlider
value={config.radius}
defaultValue={DEFAULTS.radius}
onChange={handleRadiusChange}
min={1}
max={64}
step={0.1}
marks
/>
<CompositeNumberInput
value={config.radius}
defaultValue={DEFAULTS.radius}
onChange={handleRadiusChange}
min={1}
max={4096}
step={0.1}
/>
</FormControl>
</>
);
});
FilterBlur.displayName = 'FilterBlur';

View File

@@ -1,111 +0,0 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { NoiseFilterConfig, NoiseTypes } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS, isNoiseTypes } from 'features/controlLayers/store/filters';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<NoiseFilterConfig>;
const DEFAULTS = IMAGE_FILTERS.img_noise.buildDefaults();
export const FilterNoise = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleNoiseTypeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isNoiseTypes(v?.value)) {
return;
}
onChange({ ...config, noise_type: v.value });
},
[config, onChange]
);
const handleAmountChange = useCallback(
(v: number) => {
onChange({ ...config, amount: v });
},
[config, onChange]
);
const handleColorChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, noise_color: e.target.checked });
},
[config, onChange]
);
const handleSizeChange = useCallback(
(v: number) => {
onChange({ ...config, size: v });
},
[config, onChange]
);
const options: { label: string; value: NoiseTypes }[] = useMemo(
() => [
{ label: t('controlLayers.filter.img_noise.gaussian_type'), value: 'gaussian' },
{ label: t('controlLayers.filter.img_noise.salt_and_pepper_type'), value: 'salt_and_pepper' },
],
[t]
);
const value = useMemo(() => options.filter((o) => o.value === config.noise_type)[0], [options, config.noise_type]);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_type')}</FormLabel>
<Combobox value={value} options={options} onChange={handleNoiseTypeChange} isSearchable={false} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_amount')}</FormLabel>
<CompositeSlider
value={config.amount}
defaultValue={DEFAULTS.amount}
onChange={handleAmountChange}
min={0}
max={1}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.amount}
defaultValue={DEFAULTS.amount}
onChange={handleAmountChange}
min={0}
max={1}
step={0.01}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_noise.size')}</FormLabel>
<CompositeSlider
value={config.size}
defaultValue={DEFAULTS.size}
onChange={handleSizeChange}
min={1}
max={16}
step={1}
marks
/>
<CompositeNumberInput
value={config.size}
defaultValue={DEFAULTS.size}
onChange={handleSizeChange}
min={1}
max={256}
step={1}
/>
</FormControl>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_color')}</FormLabel>
<Switch defaultChecked={DEFAULTS.noise_color} isChecked={config.noise_color} onChange={handleColorChange} />
</FormControl>
</>
);
});
FilterNoise.displayName = 'Filternoise';

View File

@@ -1,6 +1,4 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { FilterAdjustImage } from 'features/controlLayers/components/Filters/FilterAdjustImage';
import { FilterBlur } from 'features/controlLayers/components/Filters/FilterBlur';
import { FilterCannyEdgeDetection } from 'features/controlLayers/components/Filters/FilterCannyEdgeDetection';
import { FilterColorMap } from 'features/controlLayers/components/Filters/FilterColorMap';
import { FilterContentShuffle } from 'features/controlLayers/components/Filters/FilterContentShuffle';
@@ -10,7 +8,6 @@ import { FilterHEDEdgeDetection } from 'features/controlLayers/components/Filter
import { FilterLineartEdgeDetection } from 'features/controlLayers/components/Filters/FilterLineartEdgeDetection';
import { FilterMediaPipeFaceDetection } from 'features/controlLayers/components/Filters/FilterMediaPipeFaceDetection';
import { FilterMLSDDetection } from 'features/controlLayers/components/Filters/FilterMLSDDetection';
import { FilterNoise } from 'features/controlLayers/components/Filters/FilterNoise';
import { FilterPiDiNetEdgeDetection } from 'features/controlLayers/components/Filters/FilterPiDiNetEdgeDetection';
import { FilterSpandrel } from 'features/controlLayers/components/Filters/FilterSpandrel';
import type { FilterConfig } from 'features/controlLayers/store/filters';
@@ -22,10 +19,6 @@ type Props = { filterConfig: FilterConfig; onChange: (filterConfig: FilterConfig
export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
const { t } = useTranslation();
if (filterConfig.type === 'adjust_image') {
return <FilterAdjustImage config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'canny_edge_detection') {
return <FilterCannyEdgeDetection config={filterConfig} onChange={onChange} />;
}
@@ -66,14 +59,6 @@ export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
return <FilterPiDiNetEdgeDetection config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'img_blur') {
return <FilterBlur config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'img_noise') {
return <FilterNoise config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'spandrel_filter') {
return <FilterSpandrel config={filterConfig} onChange={onChange} />;
}

View File

@@ -1,8 +1,8 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCopyLayerToClipboard } from 'features/controlLayers/hooks/copyHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useCopyLayerToClipboard } from 'features/controlLayers/hooks/useCopyLayerToClipboard';
import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

View File

@@ -1,93 +0,0 @@
import { logger } from 'app/logging/logger';
import { withResultAsync } from 'common/util/result';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
import { canvasToBlob } from 'features/controlLayers/konva/util';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { startCase } from 'lodash-es';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { serializeError } from 'serialize-error';
const log = logger('canvas');
export const useCopyLayerToClipboard = () => {
const { t } = useTranslation();
const copyLayerToCipboard = useCallback(
async (
adapter:
| CanvasEntityAdapterRasterLayer
| CanvasEntityAdapterControlLayer
| CanvasEntityAdapterInpaintMask
| CanvasEntityAdapterRegionalGuidance
| null
) => {
if (!adapter) {
return;
}
const result = await withResultAsync(async () => {
const canvas = adapter.getCanvas();
const blob = await canvasToBlob(canvas);
copyBlobToClipboard(blob);
});
if (result.isOk()) {
log.trace('Layer copied to clipboard');
toast({
status: 'info',
title: t('toast.layerCopiedToClipboard'),
});
} else {
log.error({ error: serializeError(result.error) }, 'Problem copying layer to clipboard');
toast({
status: 'error',
title: t('toast.problemCopyingLayer'),
});
}
},
[t]
);
return copyLayerToCipboard;
};
export const useCopyCanvasToClipboard = (region: 'canvas' | 'bbox') => {
const { t } = useTranslation();
const canvasManager = useCanvasManager();
const copyCanvasToClipboard = useCallback(async () => {
const rect =
region === 'bbox'
? canvasManager.stateApi.getBbox().rect
: canvasManager.compositor.getVisibleRectOfType('raster_layer');
if (rect.width === 0 || rect.height === 0) {
toast({
title: t('controlLayers.copyRegionError', { region: startCase(region) }),
description: t('controlLayers.regionIsEmpty'),
status: 'warning',
});
return;
}
const result = await withResultAsync(async () => {
const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
const canvasElement = canvasManager.compositor.getCompositeCanvas(rasterAdapters, rect);
const blob = await canvasToBlob(canvasElement);
copyBlobToClipboard(blob);
});
if (result.isOk()) {
toast({ title: t('controlLayers.regionCopiedToClipboard', { region: startCase(region) }) });
} else {
log.error({ error: serializeError(result.error) }, 'Failed to save canvas to gallery');
toast({ title: t('controlLayers.copyRegionError', { region: startCase(region) }), status: 'error' });
}
}, [canvasManager.compositor, canvasManager.stateApi, region, t]);
return copyCanvasToClipboard;
};

View File

@@ -0,0 +1,50 @@
import { logger } from 'app/logging/logger';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
import { canvasToBlob } from 'features/controlLayers/konva/util';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { serializeError } from 'serialize-error';
const log = logger('canvas');
export const useCopyLayerToClipboard = () => {
const { t } = useTranslation();
const copyLayerToCipboard = useCallback(
async (
adapter:
| CanvasEntityAdapterRasterLayer
| CanvasEntityAdapterControlLayer
| CanvasEntityAdapterInpaintMask
| CanvasEntityAdapterRegionalGuidance
| null
) => {
if (!adapter) {
return;
}
try {
const canvas = adapter.getCanvas();
const blob = await canvasToBlob(canvas);
copyBlobToClipboard(blob);
log.trace('Layer copied to clipboard');
toast({
status: 'info',
title: t('toast.layerCopiedToClipboard'),
});
} catch (error) {
log.error({ error: serializeError(error) }, 'Problem copying layer to clipboard');
toast({
status: 'error',
title: t('toast.problemCopyingLayer'),
});
}
},
[t]
);
return copyLayerToCipboard;
};
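
Compared with the deleted `copyHooks.ts` above, which wrapped the copy in `withResultAsync` and branched on `result.isOk()`, this re-added hook inlines a plain try/catch; both report the outcome via a toast. A rough sketch of the two styles follows, with a stand-in `withResultAsync` whose shape is assumed for illustration (the real helper lives in `common/util/result`).

```ts
// Stand-in for common/util/result: assumed shape for illustration, not the real helper.
type Result<T> = { isOk: () => boolean; value?: T; error?: unknown };

const withResultAsync = async <T>(fn: () => Promise<T>): Promise<Result<T>> => {
  try {
    const value = await fn();
    return { isOk: () => true, value };
  } catch (error) {
    return { isOk: () => false, error };
  }
};

// Result style: the caller never throws, it inspects the returned wrapper.
export const copyViaResult = async (copy: () => Promise<void>): Promise<boolean> => {
  const result = await withResultAsync(copy);
  return result.isOk();
};

// try/catch style: same outcome, with error handling kept local to the hook.
export const copyViaTryCatch = async (copy: () => Promise<void>): Promise<boolean> => {
  try {
    await copy();
    return true;
  } catch {
    return false;
  }
};
```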

View File

@@ -68,7 +68,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
/**
* The config for the filter.
*/
$filterConfig = atom<FilterConfig>(IMAGE_FILTERS.adjust_image.buildDefaults());
$filterConfig = atom<FilterConfig>(IMAGE_FILTERS.canny_edge_detection.buildDefaults());
/**
* The initial filter config, used to reset the filter config.
@@ -212,7 +212,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
return filter.buildDefaults();
} else {
      // Otherwise, use the default filter
return IMAGE_FILTERS.adjust_image.buildDefaults();
return IMAGE_FILTERS.canny_edge_detection.buildDefaults();
}
};
@@ -297,9 +297,10 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
const imageState = imageDTOToImageObject(filterResult.value);
this.$imageState.set(imageState);
// Stash the existing image module - we will destroy it after the new image is rendered to prevent a flash
// of an empty layer
const oldImageModule = this.imageModule;
// Destroy any existing masked image and create a new one
if (this.imageModule) {
this.imageModule.destroy();
}
this.imageModule = new CanvasObjectImage(imageState, this);
@@ -308,16 +309,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.konva.group.add(this.imageModule.konva.group);
      // The filtered image has some transparency, so we need to hide the objects of the parent entity to prevent the
// two images from blending. We will show the objects again in the teardown method, which is always called after
// the filter finishes (applied or canceled).
this.parent.renderer.hideObjects();
if (oldImageModule) {
// Destroy the old image module now that the new one is rendered
oldImageModule.destroy();
}
      // The processing is complete, so we can set the last processed hash and isProcessing to false
this.$lastProcessedHash.set(hash);
@@ -433,8 +424,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
teardown = () => {
this.unsubscribe();
// Re-enable the objects of the parent entity
this.parent.renderer.showObjects();
this.konva.group.remove();
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.
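
The two sides of this hunk differ in when the previous preview image module is destroyed: one side destroys it before creating the replacement, which can leave the layer empty for a frame, while the other stashes it and destroys it only after the new image has rendered, as the "prevent a flash" comment above describes. A minimal sketch of the two orderings, using a hypothetical `Preview` interface rather than the real `CanvasObjectImage` module:

```ts
// Hypothetical preview object, standing in for CanvasObjectImage purely for illustration.
interface Preview {
  render(): Promise<void>;
  destroy(): void;
}

let current: Preview | null = null;

// Destroy-first: the layer is empty between destroy() and render(), which can flash.
export const swapDestroyFirst = async (next: Preview) => {
  current?.destroy();
  current = next;
  await next.render();
};

// Stash-then-destroy: the previous preview stays visible until the new one has rendered.
export const swapStashFirst = async (next: Preview) => {
  const old = current;
  current = next;
  await next.render();
  old?.destroy();
};
```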

View File

@@ -185,14 +185,6 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
return didRender;
};
hideObjects = () => {
this.konva.objectGroup.hide();
};
showObjects = () => {
this.konva.objectGroup.show();
};
adoptObjectRenderer = (renderer: AnyObjectRenderer) => {
this.renderers.set(renderer.id, renderer);
renderer.konva.group.moveTo(this.konva.objectGroup);

View File

@@ -10,7 +10,6 @@ import {
getKonvaNodeDebugAttrs,
getPrefixedId,
offsetCoord,
roundRect,
} from 'features/controlLayers/konva/util';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, RectWithRotation } from 'features/controlLayers/store/types';
@@ -774,7 +773,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const rect = this.getRelativeRect();
const rasterizeResult = await withResultAsync(() =>
this.parent.renderer.rasterize({
rect: roundRect(rect),
rect,
replaceObjects: true,
ignoreCache: true,
attrs: { opacity: 1, filters: [] },

View File

@@ -740,12 +740,3 @@ export const getColorAtCoordinate = (stage: Konva.Stage, coord: Coordinate): Rgb
return { r, g, b };
};
export const roundRect = (rect: Rect): Rect => {
return {
x: Math.round(rect.x),
y: Math.round(rect.y),
width: Math.round(rect.width),
height: Math.round(rect.height),
};
};

View File

@@ -6,35 +6,6 @@ import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConf
import { assert } from 'tsafe';
import { z } from 'zod';
const zAjustImageChannels = z.enum([
'Red (RGBA)',
'Green (RGBA)',
'Blue (RGBA)',
'Alpha (RGBA)',
'Cyan (CMYK)',
'Magenta (CMYK)',
'Yellow (CMYK)',
'Black (CMYK)',
'Hue (HSV)',
'Saturation (HSV)',
'Value (HSV)',
'Luminosity (LAB)',
'A (LAB)',
'B (LAB)',
'Y (YCbCr)',
'Cb (YCbCr)',
'Cr (YCbCr)',
]);
export type AjustImageChannels = z.infer<typeof zAjustImageChannels>;
export const isAjustImageChannels = (v: unknown): v is AjustImageChannels => zAjustImageChannels.safeParse(v).success;
const zAdjustImageFilterConfig = z.object({
type: z.literal('adjust_image'),
channel: zAjustImageChannels,
value: z.number(),
scale_values: z.boolean().optional(),
});
export type AdjustImageFilterConfig = z.infer<typeof zAdjustImageFilterConfig>;
const zCannyEdgeDetectionFilterConfig = z.object({
type: z.literal('canny_edge_detection'),
low_threshold: z.number().int().gte(0).lte(255),
@@ -124,30 +95,7 @@ const zSpandrelFilterConfig = z.object({
});
export type SpandrelFilterConfig = z.infer<typeof zSpandrelFilterConfig>;
const zBlurTypes = z.enum(['gaussian', 'box']);
export type BlurTypes = z.infer<typeof zBlurTypes>;
export const isBlurTypes = (v: unknown): v is BlurTypes => zBlurTypes.safeParse(v).success;
const zBlurFilterConfig = z.object({
type: z.literal('img_blur'),
blur_type: zBlurTypes,
radius: z.number().gte(0),
});
export type BlurFilterConfig = z.infer<typeof zBlurFilterConfig>;
const zNoiseTypes = z.enum(['gaussian', 'salt_and_pepper']);
export type NoiseTypes = z.infer<typeof zNoiseTypes>;
export const isNoiseTypes = (v: unknown): v is NoiseTypes => zNoiseTypes.safeParse(v).success;
const zNoiseFilterConfig = z.object({
type: z.literal('img_noise'),
noise_type: zNoiseTypes,
amount: z.number().gte(0).lte(1),
noise_color: z.boolean(),
size: z.number().int().gte(1),
});
export type NoiseFilterConfig = z.infer<typeof zNoiseFilterConfig>;
const zFilterConfig = z.discriminatedUnion('type', [
zAdjustImageFilterConfig,
zCannyEdgeDetectionFilterConfig,
zColorMapFilterConfig,
zContentShuffleFilterConfig,
@@ -161,13 +109,10 @@ const zFilterConfig = z.discriminatedUnion('type', [
zPiDiNetEdgeDetectionFilterConfig,
zDWOpenposeDetectionFilterConfig,
zSpandrelFilterConfig,
zBlurFilterConfig,
zNoiseFilterConfig,
]);
export type FilterConfig = z.infer<typeof zFilterConfig>;
const zFilterType = z.enum([
'adjust_image',
'canny_edge_detection',
'color_map',
'content_shuffle',
@@ -181,8 +126,6 @@ const zFilterType = z.enum([
'pidi_edge_detection',
'dw_openpose_detection',
'spandrel_filter',
'img_blur',
'img_noise',
]);
export type FilterType = z.infer<typeof zFilterType>;
export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success;
@@ -198,42 +141,6 @@ type ImageFilterData<T extends FilterConfig['type']> = {
};
export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key> } = {
adjust_image: {
type: 'adjust_image',
buildDefaults: () => ({
type: 'adjust_image',
channel: 'Luminosity (LAB)',
value: 1,
scale_values: false,
}),
buildGraph: ({ image_name }, { channel, value, scale_values }) => {
const graph = new Graph(getPrefixedId('adjust_image_filter'));
let node;
if (scale_values) {
node = graph.addNode({
id: getPrefixedId('img_channel_multiply'),
type: 'img_channel_multiply',
image: { image_name },
channel,
scale: value,
invert_channel: false,
});
} else {
value = Math.min(value, 2); // Limit value to a maximum of 2
node = graph.addNode({
id: getPrefixedId('img_channel_offset'),
type: 'img_channel_offset',
image: { image_name },
channel,
offset: Math.round(255 * (value - 1)), // value is in range [0, 2], offset is in range [-255, 255]
});
}
return {
graph,
outputNodeId: node.id,
};
},
},
canny_edge_detection: {
type: 'canny_edge_detection',
buildDefaults: () => ({
@@ -522,62 +429,6 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
return true;
},
},
img_blur: {
type: 'img_blur',
buildDefaults: () => ({
type: 'img_blur',
blur_type: 'gaussian',
radius: 8,
}),
buildGraph: ({ image_name }, { blur_type, radius }) => {
const graph = new Graph(getPrefixedId('img_blur'));
const node = graph.addNode({
id: getPrefixedId('img_blur'),
type: 'img_blur',
image: { image_name },
blur_type: blur_type,
radius: radius,
});
return {
graph,
outputNodeId: node.id,
};
},
},
img_noise: {
type: 'img_noise',
buildDefaults: () => ({
type: 'img_noise',
noise_type: 'gaussian',
amount: 0.3,
noise_color: true,
size: 1,
}),
buildGraph: ({ image_name }, { noise_type, amount, noise_color, size }) => {
const graph = new Graph(getPrefixedId('img_noise'));
const node = graph.addNode({
id: getPrefixedId('img_noise'),
type: 'img_noise',
image: { image_name },
noise_type: noise_type,
amount: amount,
noise_color: noise_color,
size: size,
});
const rand = graph.addNode({
id: getPrefixedId('rand_int'),
use_cache: false,
type: 'rand_int',
low: 0,
high: 2147483647,
});
graph.addEdge(rand, 'value', node, 'seed');
return {
graph,
outputNodeId: node.id,
};
},
},
} as const;
/**
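
One detail of the removed `adjust_image` builder above is worth spelling out: when `scale_values` is off, the slider value in [0, 2] is mapped to an additive channel offset in [-255, 255] via `Math.round(255 * (value - 1))`, so 1 is a no-op and 2 adds the full 255. A small standalone helper (hypothetical name) replicating that arithmetic, including the clamp applied in the removed code:

```ts
// Mapping used by the removed adjust_image filter when scale_values is false:
// a slider value in [0, 2] becomes an additive channel offset in [-255, 255].
const valueToOffset = (value: number): number => Math.round(255 * (Math.min(value, 2) - 1));

console.log(valueToOffset(1)); // 0    - no change
console.log(valueToOffset(0)); // -255 - fully darken the channel
console.log(valueToOffset(1.5)); // 128  - adds half of the full 8-bit range
console.log(valueToOffset(2)); // 255  - maximum offset
```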

View File

@@ -1,5 +1,4 @@
import type {
BlurFilterConfig,
CannyEdgeDetectionFilterConfig,
ColorMapFilterConfig,
ContentShuffleFilterConfig,
@@ -13,7 +12,6 @@ import type {
LineartEdgeDetectionFilterConfig,
MediaPipeFaceDetectionFilterConfig,
MLSDDetectionFilterConfig,
NoiseFilterConfig,
NormalMapFilterConfig,
PiDiNetEdgeDetectionFilterConfig,
} from 'features/controlLayers/store/filters';
@@ -56,7 +54,6 @@ describe('Control Adapter Types', () => {
});
test('Processor Configs', () => {
// Types derived from OpenAPI
type _BlurFilterConfig = Required<Pick<Invocation<'img_blur'>, 'type' | 'radius' | 'blur_type'>>;
type _CannyEdgeDetectionFilterConfig = Required<
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
>;
@@ -74,9 +71,6 @@ describe('Control Adapter Types', () => {
type _MLSDDetectionFilterConfig = Required<
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
>;
type _NoiseFilterConfig = Required<
Pick<Invocation<'img_noise'>, 'type' | 'noise_type' | 'amount' | 'noise_color' | 'size'>
>;
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
type _DWOpenposeDetectionFilterConfig = Required<
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
@@ -87,7 +81,6 @@ describe('Control Adapter Types', () => {
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
// The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled.
assert<Equals<_BlurFilterConfig, BlurFilterConfig>>();
assert<Equals<_CannyEdgeDetectionFilterConfig, CannyEdgeDetectionFilterConfig>>();
assert<Equals<_ColorMapFilterConfig, ColorMapFilterConfig>>();
assert<Equals<_ContentShuffleFilterConfig, ContentShuffleFilterConfig>>();
@@ -97,7 +90,6 @@ describe('Control Adapter Types', () => {
assert<Equals<_LineartEdgeDetectionFilterConfig, LineartEdgeDetectionFilterConfig>>();
assert<Equals<_MediaPipeFaceDetectionFilterConfig, MediaPipeFaceDetectionFilterConfig>>();
assert<Equals<_MLSDDetectionFilterConfig, MLSDDetectionFilterConfig>>();
assert<Equals<_NoiseFilterConfig, NoiseFilterConfig>>();
assert<Equals<_NormalMapFilterConfig, NormalMapFilterConfig>>();
assert<Equals<_DWOpenposeDetectionFilterConfig, DWOpenposeDetectionFilterConfig>>();
assert<Equals<_PiDiNetEdgeDetectionFilterConfig, PiDiNetEdgeDetectionFilterConfig>>();

View File

@@ -4,18 +4,14 @@ import { containsFiles, getFiles } from '@atlaskit/pragmatic-drag-and-drop/exter
import { preventUnhandled } from '@atlaskit/pragmatic-drag-and-drop/prevent-unhandled';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Heading } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { getStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { setFileToPaste } from 'features/controlLayers/components/CanvasPasteModal';
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
import type { DndTargetState } from 'features/dnd/types';
import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { uploadImages } from 'services/api/endpoints/images';
import { useBoardName } from 'services/api/hooks/useBoardName';
@@ -75,13 +71,12 @@ export const FullscreenDropzone = memo(() => {
const ref = useRef<HTMLDivElement>(null);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const [dndState, setDndState] = useState<DndTargetState>('idle');
const activeTab = useAppSelector(selectActiveTab);
const isImageViewerOpen = useStore($imageViewer);
const uploadFilesSchema = useMemo(() => getFilesSchema(maxImageUploadCount), [maxImageUploadCount]);
const validateAndUploadFiles = useCallback(
(files: File[]) => {
const { getState } = getStore();
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
const parseResult = uploadFilesSchema.safeParse(files);
if (!parseResult.success) {
@@ -98,15 +93,6 @@ export const FullscreenDropzone = memo(() => {
});
return;
}
// While on the canvas tab and when pasting a single image, canvas may want to create a new layer. Let it handle
// the paste event.
const [firstImageFile] = files;
if (!isImageViewerOpen && activeTab === 'canvas' && files.length === 1 && firstImageFile) {
setFileToPaste(firstImageFile);
return;
}
const autoAddBoardId = selectAutoAddBoardId(getState());
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
@@ -119,18 +105,7 @@ export const FullscreenDropzone = memo(() => {
uploadImages(uploadArgs);
},
[activeTab, isImageViewerOpen, maxImageUploadCount, t]
);
const onPaste = useCallback(
(e: ClipboardEvent) => {
if (!e.clipboardData?.files) {
return;
}
const files = Array.from(e.clipboardData.files);
validateAndUploadFiles(files);
},
[validateAndUploadFiles]
[maxImageUploadCount, t, uploadFilesSchema]
);
useEffect(() => {
@@ -169,12 +144,24 @@ export const FullscreenDropzone = memo(() => {
}, [validateAndUploadFiles]);
useEffect(() => {
window.addEventListener('paste', onPaste);
const controller = new AbortController();
document.addEventListener(
'paste',
(e) => {
if (!e.clipboardData?.files) {
return;
}
const files = Array.from(e.clipboardData.files);
validateAndUploadFiles(files);
},
{ signal: controller.signal }
);
return () => {
window.removeEventListener('paste', onPaste);
controller.abort();
};
}, [onPaste]);
}, [validateAndUploadFiles]);
return (
<Box ref={ref} data-dnd-state={dndState} sx={sx}>
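
The paste handling change above swaps a named `onPaste` callback plus `removeEventListener` for an inline listener whose cleanup is a single `AbortController.abort()` call; the `signal` option detaches every listener registered with it. A minimal standalone sketch of that pattern outside React, assuming only standard DOM APIs:

```ts
// Register a listener tied to an AbortController; aborting removes it.
const controller = new AbortController();

document.addEventListener(
  'paste',
  (e: ClipboardEvent) => {
    const files = e.clipboardData ? Array.from(e.clipboardData.files) : [];
    console.log(`pasted ${files.length} file(s)`);
  },
  { signal: controller.signal }
);

// Later (e.g. in a React effect cleanup), one abort() detaches the listener
// without needing a reference to the original callback.
controller.abort();
```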

View File

@@ -1,4 +1,3 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
@@ -10,6 +9,7 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
import type { BoardId } from 'features/gallery/store/types';
import {
addImagesToBoard,
addImagesToNodeImageFieldCollectionAction,
createNewCanvasEntityFromImage,
removeImagesFromBoard,
replaceCanvasEntityObjectsWithImage,
@@ -19,14 +19,10 @@ import {
setRegionalGuidanceReferenceImage,
setUpscaleInitialImage,
} from 'features/imageActions/actions';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('dnd');
type RecordUnknown = Record<string | symbol, unknown>;
type DndData<
@@ -272,27 +268,15 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
}
const { fieldIdentifier } = targetData.payload;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}
const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
const imageDTOs: ImageDTO[] = [];
if (singleImageDndSource.typeGuard(sourceData)) {
newValue.push({ image_name: sourceData.payload.imageDTO.image_name });
imageDTOs.push(sourceData.payload.imageDTO);
} else {
newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name })));
imageDTOs.push(...sourceData.payload.imageDTOs);
}
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue }));
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
},
};
//#endregion

View File

@@ -1,3 +1,4 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
@@ -19,26 +20,30 @@ import { selectBboxModelBase, selectBboxRect } from 'features/controlLayers/stor
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasEntityState,
CanvasEntityType,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
CanvasRenderableEntityIdentifier,
CanvasRenderableEntityState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { uniqBy } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const setGlobalReferenceImage = (arg: {
imageDTO: ImageDTO;
entityIdentifier: CanvasEntityIdentifier<'reference_image'>;
@@ -72,6 +77,54 @@ export const setNodeImageFieldImage = (arg: {
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
};
export const addImagesToNodeImageFieldCollectionAction = (arg: {
imageDTOs: ImageDTO[];
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
const uniqueImages = uniqBy(images, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};
export const removeImageFromNodeImageFieldCollectionAction = (arg: {
imageName: string;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageName, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
return;
}
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg;
dispatch(imageToCompareChanged(imageDTO));
@@ -82,7 +135,7 @@ export const createNewCanvasEntityFromImage = (arg: {
type: CanvasEntityType | 'regional_guidance_with_reference_image';
dispatch: AppDispatch;
getState: () => RootState;
overrides?: Partial<Pick<CanvasRenderableEntityState, 'isEnabled' | 'isLocked' | 'name' | 'position'>>;
overrides?: Partial<Pick<CanvasEntityState, 'isEnabled' | 'isLocked' | 'name'>>;
}) => {
const { type, imageDTO, dispatch, getState, overrides: _overrides } = arg;
const state = getState();
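
Both helpers added here normalize the field value with `uniqBy(..., 'image_name')` after modifying it, so dropping the same image twice, or re-adding one that is already present, leaves a single entry. A small illustration of that dedup step with lodash, using made-up image names:

```ts
import { uniqBy } from 'lodash-es';

type ImageField = { image_name: string };

const existing: ImageField[] = [{ image_name: 'a.png' }, { image_name: 'b.png' }];
const dropped: ImageField[] = [{ image_name: 'b.png' }, { image_name: 'c.png' }];

// Append first, then dedupe on image_name: the second 'b.png' is discarded.
const next = uniqBy([...existing, ...dropped], 'image_name');
console.log(next.map((i) => i.image_name)); // ['a.png', 'b.png', 'c.png']
```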

View File

@@ -36,15 +36,10 @@ const FieldHandle = (props: FieldHandleProps) => {
borderWidth: !isSingle(type) ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType || type.batch ? 4 : '100%',
borderRadius: isModelType ? 4 : '100%',
zIndex: 1,
transformOrigin: 'center',
};
if (type.batch) {
s.transform = 'rotate(45deg) translateX(-0.3rem) translateY(-0.3rem)';
}
if (handleType === 'target') {
s.insetInlineStart = '-1rem';
} else {

View File

@@ -1,10 +1,5 @@
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { NumberFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent';
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
import { StringGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import {
@@ -26,12 +21,8 @@ import {
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
isEnumFieldInputTemplate,
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isFloatGeneratorFieldInputInstance,
isFloatGeneratorFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isFluxVAEModelFieldInputInstance,
@@ -40,12 +31,8 @@ import {
isImageFieldCollectionInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
isIntegerFieldInputTemplate,
isIntegerGeneratorFieldInputInstance,
isIntegerGeneratorFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isLoRAModelFieldInputInstance,
@@ -64,12 +51,8 @@ import {
isSDXLRefinerModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isStringGeneratorFieldInputInstance,
isStringGeneratorFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
@@ -114,10 +97,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
return <StringFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
@@ -126,22 +105,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
if (
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
@@ -246,18 +216,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFloatGeneratorFieldInputInstance(fieldInstance) && isFloatGeneratorFieldInputTemplate(fieldTemplate)) {
return <FloatGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isIntegerGeneratorFieldInputInstance(fieldInstance) && isIntegerGeneratorFieldInputTemplate(fieldTemplate)) {
return <IntegerGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isStringGeneratorFieldInputInstance(fieldInstance) && isStringGeneratorFieldInputTemplate(fieldTemplate)) {
return <StringGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (fieldTemplate) {
// Fallback for when there is no component for the type
return null;

View File

@@ -1,6 +1,6 @@
import { Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectInvocationNode, selectNodesSlice } from 'features/nodes/store/selectors';
@@ -18,7 +18,7 @@ export const InvocationInputFieldCheck = memo(({ nodeId, fieldName, children }:
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodesSlice) => {
createSelector(selectNodesSlice, (nodesSlice) => {
const node = selectInvocationNode(nodesSlice, nodeId);
const instance = node.data.inputs[fieldName];
const template = templates[node.data.type];

View File

@@ -26,7 +26,7 @@ const EnumFieldInputComponent = (props: FieldComponentProps<EnumFieldInputInstan
);
return (
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value} size="sm">
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value}>
{fieldTemplate.options.map((option) => (
<option key={option} value={option}>
{fieldTemplate.ui_choice_labels ? fieldTemplate.ui_choice_labels[option] : option}

View File

@@ -1,57 +0,0 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { FloatGeneratorArithmeticSequence } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type FloatGeneratorArithmeticSequenceSettingsProps = {
state: FloatGeneratorArithmeticSequence;
onChange: (state: FloatGeneratorArithmeticSequence) => void;
};
export const FloatGeneratorArithmeticSequenceSettings = memo(
({ state, onChange }: FloatGeneratorArithmeticSequenceSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput
value={state.start}
onChange={onChangeStart}
min={-Infinity}
max={Infinity}
step={0.01}
/>
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
</Flex>
);
}
);
FloatGeneratorArithmeticSequenceSettings.displayName = 'FloatGeneratorArithmeticSequenceSettings';

View File

@@ -1,120 +0,0 @@
import { Flex, Select, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { FloatGeneratorArithmeticSequenceSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorArithmeticSequenceSettings';
import { FloatGeneratorLinearDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorLinearDistributionSettings';
import { FloatGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorParseStringSettings';
import { FloatGeneratorUniformRandomDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorUniformRandomDistributionSettings';
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
import { fieldFloatGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
import type { FloatGeneratorFieldInputInstance, FloatGeneratorFieldInputTemplate } from 'features/nodes/types/field';
import {
FloatGeneratorArithmeticSequenceType,
FloatGeneratorLinearDistributionType,
FloatGeneratorParseStringType,
FloatGeneratorUniformRandomDistributionType,
getFloatGeneratorDefaults,
resolveFloatGeneratorField,
} from 'features/nodes/types/field';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebounce } from 'use-debounce';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
export const FloatGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<FloatGeneratorFieldInputInstance, FloatGeneratorFieldInputTemplate>) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onChange = useCallback(
(value: FloatGeneratorFieldInputInstance['value']) => {
dispatch(
fieldFloatGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const onChangeGeneratorType = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const value = getFloatGeneratorDefaults(e.target.value as FloatGeneratorFieldInputInstance['value']['type']);
if (!value) {
return;
}
dispatch(
fieldFloatGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const [debouncedField] = useDebounce(field, 300);
const resolvedValuesAsString = useMemo(() => {
if (
debouncedField.value.type === FloatGeneratorUniformRandomDistributionType &&
isNil(debouncedField.value.seed)
) {
const { count } = debouncedField.value;
return `<${t('nodes.generatorNRandomValues', { count })}>`;
}
const resolvedValues = resolveFloatGeneratorField(debouncedField);
if (resolvedValues.length === 0) {
return `<${t('nodes.generatorNoValues')}>`;
} else {
return resolvedValues.map((val) => round(val, 2)).join(', ');
}
}, [debouncedField, t]);
return (
<Flex flexDir="column" gap={2}>
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
<option value={FloatGeneratorArithmeticSequenceType}>{t('nodes.arithmeticSequence')}</option>
<option value={FloatGeneratorLinearDistributionType}>{t('nodes.linearDistribution')}</option>
<option value={FloatGeneratorUniformRandomDistributionType}>{t('nodes.uniformRandomDistribution')}</option>
<option value={FloatGeneratorParseStringType}>{t('nodes.parseString')}</option>
</Select>
{field.value.type === FloatGeneratorArithmeticSequenceType && (
<FloatGeneratorArithmeticSequenceSettings state={field.value} onChange={onChange} />
)}
{field.value.type === FloatGeneratorLinearDistributionType && (
<FloatGeneratorLinearDistributionSettings state={field.value} onChange={onChange} />
)}
{field.value.type === FloatGeneratorUniformRandomDistributionType && (
<FloatGeneratorUniformRandomDistributionSettings state={field.value} onChange={onChange} />
)}
{field.value.type === FloatGeneratorParseStringType && (
<FloatGeneratorParseStringSettings state={field.value} onChange={onChange} />
)}
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
<Flex w="full" h="auto">
<OverlayScrollbarsComponent
className="nodrag nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
{resolvedValuesAsString}
</Text>
</OverlayScrollbarsComponent>
</Flex>
</Flex>
</Flex>
);
}
);
FloatGeneratorFieldInputComponent.displayName = 'FloatGeneratorFieldInputComponent';

View File

@@ -1,57 +0,0 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { FloatGeneratorLinearDistribution } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type FloatGeneratorLinearDistributionSettingsProps = {
state: FloatGeneratorLinearDistribution;
onChange: (state: FloatGeneratorLinearDistribution) => void;
};
export const FloatGeneratorLinearDistributionSettings = memo(
({ state, onChange }: FloatGeneratorLinearDistributionSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeEnd = useCallback(
(end: number) => {
onChange({ ...state, end });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput
value={state.start}
onChange={onChangeStart}
min={-Infinity}
max={Infinity}
step={0.01}
/>
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.end')}</FormLabel>
<CompositeNumberInput value={state.end} onChange={onChangeEnd} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
</Flex>
);
}
);
FloatGeneratorLinearDistributionSettings.displayName = 'FloatGeneratorLinearDistributionSettings';

View File

@@ -1,39 +0,0 @@
import { Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
import type { FloatGeneratorParseString } from 'features/nodes/types/field';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type FloatGeneratorParseStringSettingsProps = {
state: FloatGeneratorParseString;
onChange: (state: FloatGeneratorParseString) => void;
};
export const FloatGeneratorParseStringSettings = memo(({ state, onChange }: FloatGeneratorParseStringSettingsProps) => {
const { t } = useTranslation();
const onChangeSplitOn = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...state, splitOn: e.target.value });
},
[onChange, state]
);
const onChangeInput = useCallback(
(input: string) => {
onChange({ ...state, input });
},
[onChange, state]
);
return (
<Flex gap={2} flexDir="column">
<FormControl orientation="vertical">
<FormLabel>{t('nodes.splitOn')}</FormLabel>
<Input value={state.splitOn} onChange={onChangeSplitOn} />
</FormControl>
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
</Flex>
);
});
FloatGeneratorParseStringSettings.displayName = 'FloatGeneratorParseStringSettings';

View File

@@ -1,78 +0,0 @@
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { FloatGeneratorUniformRandomDistribution } from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type FloatGeneratorUniformRandomDistributionSettingsProps = {
state: FloatGeneratorUniformRandomDistribution;
onChange: (state: FloatGeneratorUniformRandomDistribution) => void;
};
export const FloatGeneratorUniformRandomDistributionSettings = memo(
({ state, onChange }: FloatGeneratorUniformRandomDistributionSettingsProps) => {
const { t } = useTranslation();
const onChangeMin = useCallback(
(min: number) => {
onChange({ ...state, min });
},
[onChange, state]
);
const onChangeMax = useCallback(
(max: number) => {
onChange({ ...state, max });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
const onToggleSeed = useCallback(() => {
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
}, [onChange, state]);
const onChangeSeed = useCallback(
(seed?: number | null) => {
onChange({ ...state, seed });
},
[onChange, state]
);
return (
<Flex gap={2} flexDir="column">
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.min')}</FormLabel>
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.max')}</FormLabel>
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel alignItems="center" justifyContent="space-between" m={0} display="flex" w="full">
{t('common.seed')}
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
</FormLabel>
<CompositeNumberInput
isDisabled={isNil(state.seed)}
// This cast is safe only because we disable the element when seed is not a number - the `...` is
// rendered in the input field in this case
value={state.seed ?? ('...' as unknown as number)}
onChange={onChangeSeed}
min={-Infinity}
max={Infinity}
/>
</FormControl>
</Flex>
</Flex>
);
}
);
FloatGeneratorUniformRandomDistributionSettings.displayName = 'FloatGeneratorUniformRandomDistributionSettings';

View File

@@ -1,85 +0,0 @@
import { FormControl, FormLabel, IconButton, Spacer, Textarea } from '@invoke-ai/ui-library';
import { toast } from 'features/toast/toast';
import { isString } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiUploadFill } from 'react-icons/pi';
const MAX_SIZE = 1024 * 128; // 128KB, we don't want to load huge files into node values...
type Props = {
value: string;
onChange: (value: string) => void;
};
export const GeneratorTextareaWithFileUpload = memo(({ value, onChange }: Props) => {
const { t } = useTranslation();
const onDropAccepted = useCallback(
(files: File[]) => {
const file = files[0];
if (!file) {
return;
}
const reader = new FileReader();
reader.onload = () => {
const result = reader.result;
if (!isString(result)) {
return;
}
onChange(result);
};
reader.onerror = () => {
toast({
title: 'Failed to load file',
status: 'error',
});
};
reader.readAsText(file);
},
[onChange]
);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'text/csv': ['.csv'], 'text/plain': ['.txt'] },
maxSize: MAX_SIZE,
onDropAccepted,
noDrag: true,
multiple: false,
});
const onChangeInput = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
onChange(e.target.value);
},
[onChange]
);
return (
<FormControl orientation="vertical" position="relative" alignItems="stretch">
<FormLabel m={0} display="flex" alignItems="center">
{t('common.input')}
<Spacer />
<IconButton
tooltip={t('nodes.generatorLoadFromFile')}
aria-label={t('nodes.generatorLoadFromFile')}
icon={<PiUploadFill />}
variant="link"
{...getRootProps()}
/>
<input {...getInputProps()} />
</FormLabel>
<Textarea
className="nowheel nodrag nopan"
value={value}
onChange={onChangeInput}
p={2}
resize="none"
rows={5}
fontSize="sm"
/>
</FormControl>
);
});
GeneratorTextareaWithFileUpload.displayName = 'GeneratorTextareaWithFileUpload';

View File

@@ -10,9 +10,9 @@ import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type { ImageField } from 'features/nodes/types/common';
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
@@ -61,12 +61,15 @@ export const ImageFieldCollectionInputComponent = memo(
);
const onRemoveImage = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldImageCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
(imageName: string) => {
removeImageFromNodeImageFieldCollectionAction({
imageName,
fieldIdentifier: { nodeId, fieldName: field.name },
dispatch: store.dispatch,
getState: store.getState,
});
},
[field.name, field.value, nodeId, store]
[field.name, nodeId, store.dispatch, store.getState]
);
return (
@@ -87,7 +90,7 @@ export const ImageFieldCollectionInputComponent = memo(
isError={isInvalid}
onUpload={onUpload}
fontSize={24}
variant="ghost"
variant="outline"
/>
)}
{field.value && field.value.length > 0 && (
@@ -99,9 +102,9 @@ export const ImageFieldCollectionInputComponent = memo(
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(4, 1fr)" gap={1}>
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<ImageGridItemContent value={value} index={index} onRemoveImage={onRemoveImage} />
{field.value.map(({ image_name }) => (
<GridItem key={image_name} position="relative" className="nodrag">
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
</GridItem>
))}
</Grid>
@@ -121,11 +124,11 @@ export const ImageFieldCollectionInputComponent = memo(
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';
const ImageGridItemContent = memo(
({ value, index, onRemoveImage }: { value: ImageField; index: number; onRemoveImage: (index: number) => void }) => {
const query = useGetImageDTOQuery(value.image_name);
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
const query = useGetImageDTOQuery(imageName);
const onClickRemove = useCallback(() => {
onRemoveImage(index);
}, [index, onRemoveImage]);
onRemoveImage(imageName);
}, [imageName, onRemoveImage]);
if (query.isLoading) {
return <IAINoContentFallbackWithSpinner />;

View File

@@ -1,51 +0,0 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { IntegerGeneratorArithmeticSequence } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type IntegerGeneratorArithmeticSequenceSettingsProps = {
state: IntegerGeneratorArithmeticSequence;
onChange: (state: IntegerGeneratorArithmeticSequence) => void;
};
export const IntegerGeneratorArithmeticSequenceSettings = memo(
({ state, onChange }: IntegerGeneratorArithmeticSequenceSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
</Flex>
);
}
);
IntegerGeneratorArithmeticSequenceSettings.displayName = 'IntegerGeneratorArithmeticSequenceSettings';

View File

@@ -1,122 +0,0 @@
import { Flex, Select, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { IntegerGeneratorArithmeticSequenceSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorArithmeticSequenceSettings';
import { IntegerGeneratorLinearDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorLinearDistributionSettings';
import { IntegerGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorParseStringSettings';
import { IntegerGeneratorUniformRandomDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorUniformRandomDistributionSettings';
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
import { fieldIntegerGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
import type {
IntegerGeneratorFieldInputInstance,
IntegerGeneratorFieldInputTemplate,
} from 'features/nodes/types/field';
import {
getIntegerGeneratorDefaults,
IntegerGeneratorArithmeticSequenceType,
IntegerGeneratorLinearDistributionType,
IntegerGeneratorParseStringType,
IntegerGeneratorUniformRandomDistributionType,
resolveIntegerGeneratorField,
} from 'features/nodes/types/field';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebounce } from 'use-debounce';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
export const IntegerGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<IntegerGeneratorFieldInputInstance, IntegerGeneratorFieldInputTemplate>) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onChange = useCallback(
(value: IntegerGeneratorFieldInputInstance['value']) => {
dispatch(
fieldIntegerGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const onChangeGeneratorType = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const value = getIntegerGeneratorDefaults(
e.target.value as IntegerGeneratorFieldInputInstance['value']['type']
);
dispatch(
fieldIntegerGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const [debouncedField] = useDebounce(field, 300);
const resolvedValuesAsString = useMemo(() => {
if (
debouncedField.value.type === IntegerGeneratorUniformRandomDistributionType &&
isNil(debouncedField.value.seed)
) {
const { count } = debouncedField.value;
return `<${t('nodes.generatorNRandomValues', { count })}>`;
}
const resolvedValues = resolveIntegerGeneratorField(debouncedField);
if (resolvedValues.length === 0) {
return `<${t('nodes.generatorNoValues')}>`;
} else {
return resolvedValues.map((val) => round(val, 2)).join(', ');
}
}, [debouncedField, t]);
return (
<Flex flexDir="column" gap={2}>
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
<option value={IntegerGeneratorArithmeticSequenceType}>{t('nodes.arithmeticSequence')}</option>
<option value={IntegerGeneratorLinearDistributionType}>{t('nodes.linearDistribution')}</option>
<option value={IntegerGeneratorUniformRandomDistributionType}>{t('nodes.uniformRandomDistribution')}</option>
<option value={IntegerGeneratorParseStringType}>{t('nodes.parseString')}</option>
</Select>
{field.value.type === IntegerGeneratorArithmeticSequenceType && (
<IntegerGeneratorArithmeticSequenceSettings state={field.value} onChange={onChange} />
)}
{field.value.type === IntegerGeneratorLinearDistributionType && (
<IntegerGeneratorLinearDistributionSettings state={field.value} onChange={onChange} />
)}
{field.value.type === IntegerGeneratorUniformRandomDistributionType && (
<IntegerGeneratorUniformRandomDistributionSettings state={field.value} onChange={onChange} />
)}
{field.value.type === IntegerGeneratorParseStringType && (
<IntegerGeneratorParseStringSettings state={field.value} onChange={onChange} />
)}
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
<Flex w="full" h="auto">
<OverlayScrollbarsComponent
className="nodrag nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
{resolvedValuesAsString}
</Text>
</OverlayScrollbarsComponent>
</Flex>
</Flex>
</Flex>
);
}
);
IntegerGeneratorFieldInputComponent.displayName = 'IntegerGeneratorFieldInputComponent';

View File

@@ -1,51 +0,0 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { IntegerGeneratorLinearDistribution } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type IntegerGeneratorLinearDistributionSettingsProps = {
state: IntegerGeneratorLinearDistribution;
onChange: (state: IntegerGeneratorLinearDistribution) => void;
};
export const IntegerGeneratorLinearDistributionSettings = memo(
({ state, onChange }: IntegerGeneratorLinearDistributionSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeEnd = useCallback(
(end: number) => {
onChange({ ...state, end });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.end')}</FormLabel>
<CompositeNumberInput value={state.end} onChange={onChangeEnd} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
</Flex>
);
}
);
IntegerGeneratorLinearDistributionSettings.displayName = 'IntegerGeneratorLinearDistributionSettings';

View File

@@ -1,41 +0,0 @@
import { Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
import type { IntegerGeneratorParseString } from 'features/nodes/types/field';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type IntegerGeneratorParseStringSettingsProps = {
state: IntegerGeneratorParseString;
onChange: (state: IntegerGeneratorParseString) => void;
};
export const IntegerGeneratorParseStringSettings = memo(
({ state, onChange }: IntegerGeneratorParseStringSettingsProps) => {
const { t } = useTranslation();
const onChangeSplitOn = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...state, splitOn: e.target.value });
},
[onChange, state]
);
const onChangeInput = useCallback(
(input: string) => {
onChange({ ...state, input });
},
[onChange, state]
);
return (
<Flex gap={2} flexDir="column">
<FormControl orientation="vertical">
<FormLabel>{t('nodes.splitOn')}</FormLabel>
<Input value={state.splitOn} onChange={onChangeSplitOn} />
</FormControl>
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
</Flex>
);
}
);
IntegerGeneratorParseStringSettings.displayName = 'IntegerGeneratorParseStringSettings';

Some files were not shown because too many files have changed in this diff.