mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-22 01:38:17 -05:00
Compare commits
148 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7596c07a64 | ||
|
|
98fd1d949b | ||
|
|
6312e6aa8f | ||
|
|
6435f11bae | ||
|
|
1c69b9b1fa | ||
|
|
731970ff88 | ||
|
|
038bac1614 | ||
|
|
ed9efe7740 | ||
|
|
ffa0beba7a | ||
|
|
75d793f1c4 | ||
|
|
2b086917e0 | ||
|
|
a9f2738086 | ||
|
|
3a56799ea5 | ||
|
|
3162ce94dc | ||
|
|
c0dc6ac4e1 | ||
|
|
fed1995525 | ||
|
|
5006e23456 | ||
|
|
2f063bddda | ||
|
|
23a26422fd | ||
|
|
434f195a96 | ||
|
|
6a4c2d692c | ||
|
|
5127a07cf9 | ||
|
|
0b4c6f0ab4 | ||
|
|
d8450033ea | ||
|
|
3938736bd8 | ||
|
|
fb2c7b9566 | ||
|
|
29449ec27d | ||
|
|
e38f778d28 | ||
|
|
f5e78436a8 | ||
|
|
6a15b5d9be | ||
|
|
a629102c87 | ||
|
|
848ade8ab8 | ||
|
|
2110feb01c | ||
|
|
f3e1821957 | ||
|
|
bbcf93089a | ||
|
|
66f41aa307 | ||
|
|
8a709766b3 | ||
|
|
efaa20a7a1 | ||
|
|
3e4c808b23 | ||
|
|
00e3931af4 | ||
|
|
08bea07f8b | ||
|
|
166d2f0e39 | ||
|
|
21f346717a | ||
|
|
f966fb8b9c | ||
|
|
c2b20a5387 | ||
|
|
bed9089fe6 | ||
|
|
d34a4f765c | ||
|
|
efe4708b8b | ||
|
|
7cb1f61a9e | ||
|
|
6e2ef34cba | ||
|
|
d208b99a47 | ||
|
|
47eeafa5cb | ||
|
|
0cb00fbe53 | ||
|
|
a7e8ed3bc2 | ||
|
|
22eb25be48 | ||
|
|
a077f3fefc | ||
|
|
c013a6e38d | ||
|
|
6cfeb71bed | ||
|
|
534f993023 | ||
|
|
67f9b6420c | ||
|
|
61bf065237 | ||
|
|
e78cf889ee | ||
|
|
5d13f0ba15 | ||
|
|
633b9afa46 | ||
|
|
f1889b259d | ||
|
|
ed21d0b57e | ||
|
|
df90da28e1 | ||
|
|
702054aa62 | ||
|
|
636ec1de6e | ||
|
|
063d07fd41 | ||
|
|
c78eac624e | ||
|
|
05de3b7a84 | ||
|
|
9cc2232b6f | ||
|
|
9fdc06b447 | ||
|
|
5ea3ec5cc8 | ||
|
|
f13a07ba6a | ||
|
|
a913f0163d | ||
|
|
f7cfbd1323 | ||
|
|
2806b60701 | ||
|
|
d8c3af624b | ||
|
|
feed44b68d | ||
|
|
247f3b5d67 | ||
|
|
8e14f9d971 | ||
|
|
bdb44ee48d | ||
|
|
b57f5330c5 | ||
|
|
ade3c015b4 | ||
|
|
7fe4d4c21a | ||
|
|
133a7fde55 | ||
|
|
6375214878 | ||
|
|
b9972be7f1 | ||
|
|
e61c5a3f26 | ||
|
|
8c633786f6 | ||
|
|
8703eea49b | ||
|
|
c8888be4c3 | ||
|
|
11963a65a4 | ||
|
|
ab6422fdf7 | ||
|
|
1f8632029e | ||
|
|
88a762474d | ||
|
|
e6dd721e33 | ||
|
|
2a09604baf | ||
|
|
f94f00ede0 | ||
|
|
37af281299 | ||
|
|
fc82775d7a | ||
|
|
9ed46f60b7 | ||
|
|
9a389e6b93 | ||
|
|
2ef1ecf381 | ||
|
|
41de112932 | ||
|
|
e9714fe476 | ||
|
|
3f29293e39 | ||
|
|
db1aa38e98 | ||
|
|
12717d4a4d | ||
|
|
1953f3cbcd | ||
|
|
3469fc9843 | ||
|
|
7cdd4187a9 | ||
|
|
ad66c101d2 | ||
|
|
28d3356710 | ||
|
|
81e70fb9d2 | ||
|
|
971c425734 | ||
|
|
b09008c530 | ||
|
|
f9f99f873d | ||
|
|
7f93f1b600 | ||
|
|
b1d336ce8a | ||
|
|
40c7be8f5d | ||
|
|
24218b34bf | ||
|
|
d970c6d6d5 | ||
|
|
e5308be0bb | ||
|
|
7d5687e9ff | ||
|
|
654e992630 | ||
|
|
21f247f499 | ||
|
|
8bcd9fe4b7 | ||
|
|
637b93d2d8 | ||
|
|
565b160060 | ||
|
|
bdd0b90769 | ||
|
|
4377158503 | ||
|
|
c8c27079ed | ||
|
|
d8b9a8d0dd | ||
|
|
39a4608d15 | ||
|
|
cd2d5431db | ||
|
|
c04cdd9779 | ||
|
|
b86ac5e049 | ||
|
|
665236bb79 | ||
|
|
f45400a275 | ||
|
|
be53b89203 | ||
|
|
a215eeaabf | ||
|
|
d86b392bfd | ||
|
|
3e9e45b177 | ||
|
|
907d960745 | ||
|
|
bfdace6437 |
3
.gitattributes
vendored
3
.gitattributes
vendored
@@ -2,4 +2,5 @@
|
||||
# Only affects text files and ignores other file types.
|
||||
# For more info see: https://www.aleksandrhovhannisyan.com/blog/crlf-vs-lf-normalizing-line-endings-in-git/
|
||||
* text=auto
|
||||
docker/** text eol=lf
|
||||
docker/** text eol=lf
|
||||
tests/test_model_probe/stripped_models/** filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
3
.github/workflows/python-tests.yml
vendored
3
.github/workflows/python-tests.yml
vendored
@@ -72,7 +72,8 @@ jobs:
|
||||
PIP_USE_PEP517: '1'
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
# https://github.com/nschloe/action-cached-lfs-checkout
|
||||
uses: nschloe/action-cached-lfs-checkout@f46300cd8952454b9f0a21a3d133d4bd5684cfc2
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
|
||||
@@ -18,9 +18,19 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
|
||||
2. [Fork and clone][forking link] the [InvokeAI repo][repo link].
|
||||
|
||||
3. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
|
||||
3. This repository uses Git LFS to manage large files. To ensure all assets are downloaded:
|
||||
- Install git-lfs → [Download here](https://git-lfs.com/)
|
||||
- Enable automatic LFS fetching for this repository:
|
||||
```shell
|
||||
git config lfs.fetchinclude "*"
|
||||
```
|
||||
- Fetch files from LFS (only needs to be done once; subsequent `git pull` will fetch changes automatically):
|
||||
```
|
||||
git lfs pull
|
||||
```
|
||||
4. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
|
||||
|
||||
4. Follow the [manual install][manual install link] guide, with some modifications to the install command:
|
||||
5. Follow the [manual install][manual install link] guide, with some modifications to the install command:
|
||||
|
||||
- Use `.` instead of `invokeai` to install from the current directory. You don't need to specify the version.
|
||||
|
||||
@@ -34,19 +44,19 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
|
||||
```
|
||||
|
||||
5. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
|
||||
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
|
||||
|
||||
This is because the UI build is not distributed with the source code. You need to build it manually. End the running server instance.
|
||||
|
||||
If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
|
||||
|
||||
6. Install the frontend dev toolchain:
|
||||
7. Install the frontend dev toolchain:
|
||||
|
||||
- [`nodejs`](https://nodejs.org/) (v20+)
|
||||
|
||||
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
|
||||
|
||||
7. Do a production build of the frontend:
|
||||
8. Do a production build of the frontend:
|
||||
|
||||
```sh
|
||||
cd <PATH_TO_INVOKEAI_REPO>/invokeai/frontend/web
|
||||
@@ -54,7 +64,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
pnpm build
|
||||
```
|
||||
|
||||
8. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.
|
||||
9. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.
|
||||
|
||||
## Updating the UI
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
ControlLoRAModel = "ControlLoRAModelField"
|
||||
SigLipModel = "SigLipModelField"
|
||||
FluxReduxModel = "FluxReduxModelField"
|
||||
LlavaOnevisionModel = "LLaVAModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@@ -205,6 +206,8 @@ class FieldDescriptions:
|
||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
||||
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
|
||||
flux_redux_conditioning = "FLUX Redux conditioning tensor"
|
||||
vllm_model = "The VLLM model to use"
|
||||
flux_fill_conditioning = "FLUX Fill conditioning tensor"
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
@@ -274,6 +277,13 @@ class FluxReduxConditioningField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class FluxFillConditioningField(BaseModel):
|
||||
"""A FLUX Fill conditioning field."""
|
||||
|
||||
image: ImageField = Field(description="The FLUX Fill reference image.")
|
||||
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
|
||||
|
||||
|
||||
class SD3ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
FluxFillConditioningField,
|
||||
FluxReduxConditioningField,
|
||||
ImageField,
|
||||
Input,
|
||||
@@ -48,7 +49,7 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.model_manager.config import ModelFormat, ModelVariantType
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -62,7 +63,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="3.2.3",
|
||||
version="3.3.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -109,6 +110,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description="FLUX Redux conditioning tensor.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
fill_conditioning: FluxFillConditioningField | None = InputField(
|
||||
default=None,
|
||||
description="FLUX Fill conditioning.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
cfg_scale_start_step: int = InputField(
|
||||
default=0,
|
||||
@@ -261,8 +267,19 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if is_schnell and self.control_lora:
|
||||
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
|
||||
|
||||
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
# Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
|
||||
img_cond: torch.Tensor | None = None
|
||||
is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore
|
||||
if is_flux_fill:
|
||||
img_cond = self._prep_flux_fill_img_cond(
|
||||
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
|
||||
)
|
||||
else:
|
||||
if self.fill_conditioning is not None:
|
||||
raise ValueError("fill_conditioning was provided, but the model is not a FLUX Fill model.")
|
||||
|
||||
if self.control_lora is not None:
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
@@ -271,7 +288,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
img_cond = pack(img_cond) if img_cond is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
@@ -664,7 +680,70 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
img_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
|
||||
return pack(img_cond)
|
||||
|
||||
def _prep_flux_fill_img_cond(
|
||||
self, context: InvocationContext, device: torch.device, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""Prepare the FLUX Fill conditioning. This method should be called iff the model is a FLUX Fill model.
|
||||
|
||||
This logic is based on:
|
||||
https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
|
||||
"""
|
||||
# Validate inputs.
|
||||
if self.fill_conditioning is None:
|
||||
raise ValueError("A FLUX Fill model is being used without fill_conditioning.")
|
||||
# TODO(ryand): We should probable rename controlnet_vae. It's used for more than just ControlNets.
|
||||
if self.controlnet_vae is None:
|
||||
raise ValueError("A FLUX Fill model is being used without controlnet_vae.")
|
||||
if self.control_lora is not None:
|
||||
raise ValueError(
|
||||
"A FLUX Fill model is being used, but a control_lora was provided. Control LoRAs are not compatible with FLUX Fill models."
|
||||
)
|
||||
|
||||
# Log input warnings related to FLUX Fill usage.
|
||||
if self.denoise_mask is not None:
|
||||
context.logger.warning(
|
||||
"Both fill_conditioning and a denoise_mask were provided. You probably meant to use one or the other."
|
||||
)
|
||||
if self.guidance < 25.0:
|
||||
context.logger.warning("A guidance value of ~30.0 is recommended for FLUX Fill models.")
|
||||
|
||||
# Load the conditioning image and resize it to the target image size.
|
||||
cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
|
||||
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
|
||||
cond_img = np.array(cond_img)
|
||||
cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
|
||||
cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
|
||||
cond_img = cond_img.to(device=device, dtype=dtype)
|
||||
|
||||
# Load the mask and resize it to the target image size.
|
||||
mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
|
||||
# We expect mask to be a bool tensor with shape [1, H, W].
|
||||
assert mask.dtype == torch.bool
|
||||
assert mask.dim() == 3
|
||||
assert mask.shape[0] == 1
|
||||
mask = tv_resize(mask, size=[self.height, self.width], interpolation=tv_transforms.InterpolationMode.NEAREST)
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
mask = einops.rearrange(mask, "1 h w -> 1 1 h w")
|
||||
|
||||
# Prepare image conditioning.
|
||||
cond_img = cond_img * (1 - mask)
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
cond_img = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=cond_img)
|
||||
cond_img = pack(cond_img)
|
||||
|
||||
# Prepare mask conditioning.
|
||||
mask = mask[:, 0, :, :]
|
||||
# Rearrange mask to a 16-channel representation that matches the shape of the VAE-encoded latent space.
|
||||
mask = einops.rearrange(mask, "b (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
|
||||
mask = pack(mask)
|
||||
|
||||
# Merge image and mask conditioning.
|
||||
img_cond = torch.cat((cond_img, mask), dim=-1)
|
||||
return img_cond
|
||||
|
||||
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
|
||||
if self.ip_adapter is None:
|
||||
|
||||
46
invokeai/app/invocations/flux_fill.py
Normal file
46
invokeai/app/invocations/flux_fill.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxFillConditioningField,
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_fill_output")
|
||||
class FluxFillOutput(BaseInvocationOutput):
|
||||
"""The conditioning output of a FLUX Fill invocation."""
|
||||
|
||||
fill_cond: FluxFillConditioningField = OutputField(
|
||||
description=FieldDescriptions.flux_redux_conditioning, title="Conditioning"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_fill",
|
||||
title="FLUX Fill Conditioning",
|
||||
tags=["inpaint"],
|
||||
category="inpaint",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxFillInvocation(BaseInvocation):
|
||||
"""Prepare the FLUX Fill conditioning data."""
|
||||
|
||||
image: ImageField = InputField(description="The FLUX Fill reference image.")
|
||||
mask: TensorField = InputField(
|
||||
description="The bool inpainting mask. Excluded regions should be set to "
|
||||
"False, included regions should be set to True.",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxFillOutput:
|
||||
return FluxFillOutput(fill_cond=FluxFillConditioningField(image=self.image, mask=self.mask))
|
||||
@@ -28,10 +28,10 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"flux_lora_loader",
|
||||
title="FLUX LoRA",
|
||||
title="Apply LoRA - FLUX",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
@@ -107,10 +107,10 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"flux_lora_collection_loader",
|
||||
title="FLUX LoRA Collection Loader",
|
||||
title="Apply LoRA Collection - FLUX",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.3.0",
|
||||
version="1.3.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
|
||||
@@ -1051,7 +1051,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Internal,
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Handles Canvas V2 image output masking and cropping"""
|
||||
@@ -1089,6 +1089,112 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.0"
|
||||
)
|
||||
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
|
||||
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
|
||||
"""
|
||||
|
||||
mask: ImageField = InputField(description="The mask to expand")
|
||||
threshold: int = InputField(default=0, ge=0, le=255, description="The threshold for the binary mask (0-255)")
|
||||
fade_size_px: int = InputField(default=32, ge=0, description="The size of the fade in pixels")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
np_mask = numpy.array(pil_mask)
|
||||
|
||||
# Threshold the mask to create a binary mask - 0 for black, 255 for white
|
||||
# If we don't threshold we can get some weird artifacts
|
||||
np_mask = numpy.where(np_mask > self.threshold, 255, 0).astype(numpy.uint8)
|
||||
|
||||
# Create a mask for the black region (1 where black, 0 otherwise)
|
||||
black_mask = (np_mask == 0).astype(numpy.uint8)
|
||||
|
||||
# Invert the black region
|
||||
bg_mask = 1 - black_mask
|
||||
|
||||
# Create a distance transform of the inverted mask
|
||||
dist = cv2.distanceTransform(bg_mask, cv2.DIST_L2, 5)
|
||||
|
||||
# Normalize distances so that pixels <fade_size_px become a linear gradient (0 to 1)
|
||||
d_norm = numpy.clip(dist / self.fade_size_px, 0, 1)
|
||||
|
||||
# Control points: x values (normalized distance) and corresponding fade pct y values.
|
||||
|
||||
# There are some magic numbers here that are used to create a smooth transition:
|
||||
# - The first point is at 0% of fade size from edge of mask (meaning the edge of the mask), and is 0% fade (black)
|
||||
# - The second point is 1px from the edge of the mask and also has 0% fade, effectively expanding the mask
|
||||
# by 1px. This fixes an issue where artifacts can occur at the edge of the mask
|
||||
# - The third point is at 20% of the fade size from the edge of the mask and has 20% fade
|
||||
# - The fourth point is at 80% of the fade size from the edge of the mask and has 90% fade
|
||||
# - The last point is at 100% of the fade size from the edge of the mask and has 100% fade (white)
|
||||
|
||||
# x values: 0 = mask edge, 1 = fade_size_px from edge
|
||||
x_control = numpy.array([0.0, 1.0 / self.fade_size_px, 0.2, 0.8, 1.0])
|
||||
# y values: 0 = black, 1 = white
|
||||
y_control = numpy.array([0.0, 0.0, 0.2, 0.9, 1.0])
|
||||
|
||||
# Fit a cubic polynomial that smoothly passes through the control points
|
||||
coeffs = numpy.polyfit(x_control, y_control, 3)
|
||||
poly = numpy.poly1d(coeffs)
|
||||
|
||||
# Evaluate and clip the smooth mapping
|
||||
feather = numpy.clip(poly(d_norm), 0, 1)
|
||||
|
||||
# Build final image.
|
||||
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))
|
||||
|
||||
# Convert back to PIL, grayscale
|
||||
pil_result = Image.fromarray(np_result.astype(numpy.uint8), mode="L")
|
||||
|
||||
image_dto = context.images.save(image=pil_result, image_category=ImageCategory.MASK)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"apply_mask_to_image",
|
||||
title="Apply Mask to Image",
|
||||
tags=["image", "mask", "blend"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""
|
||||
Extracts a region from a generated image using a mask and blends it seamlessly onto a source image.
|
||||
The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
"""
|
||||
|
||||
image: ImageField = InputField(description="The image from which to extract the masked region")
|
||||
mask: ImageField = InputField(description="The mask defining the region (black=keep, white=discard)")
|
||||
invert_mask: bool = InputField(
|
||||
default=False,
|
||||
description="Whether to invert the mask before applying it",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Load images
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
if self.invert_mask:
|
||||
# Invert the mask if requested
|
||||
mask = ImageOps.invert(mask.copy())
|
||||
|
||||
# Combine the mask as the alpha channel of the image
|
||||
r, g, b, _ = image.split() # Split the image into RGB and alpha channels
|
||||
result_image = Image.merge("RGBA", (r, g, b, mask)) # Use the mask as the new alpha channel
|
||||
|
||||
# Save the resulting image
|
||||
image_dto = context.images.save(image=result_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_noise",
|
||||
title="Add Image Noise",
|
||||
|
||||
60
invokeai/app/invocations/llava_onevision_vllm.py
Normal file
60
invokeai/app/invocations/llava_onevision_vllm.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import StringOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation("llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], category="vllm", version="1.0.0")
|
||||
class LlavaOnevisionVllmInvocation(BaseInvocation):
|
||||
"""Run a LLaVA OneVision VLLM model."""
|
||||
|
||||
images: list[ImageField] | ImageField | None = InputField(default=None, max_length=3, description="Input image.")
|
||||
prompt: str = InputField(
|
||||
default="",
|
||||
description="Input text prompt.",
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
vllm_model: ModelIdentifierField = InputField(
|
||||
title="LLaVA Model Type",
|
||||
description=FieldDescriptions.vllm_model,
|
||||
ui_type=UIType.LlavaOnevisionModel,
|
||||
)
|
||||
|
||||
@field_validator("images", mode="before")
|
||||
def listify_images(cls, v: Any) -> list:
|
||||
if v is None:
|
||||
return v
|
||||
if not isinstance(v, list):
|
||||
return [v]
|
||||
return v
|
||||
|
||||
def _get_images(self, context: InvocationContext) -> list[Image]:
|
||||
if self.images is None:
|
||||
return []
|
||||
|
||||
image_fields = self.images if isinstance(self.images, list) else [self.images]
|
||||
return [context.images.get_pil(image_field.image_name, "RGB") for image_field in image_fields]
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
images = self._get_images(context)
|
||||
|
||||
with context.models.load(self.vllm_model) as vllm_model:
|
||||
assert isinstance(vllm_model, LlavaOnevisionModel)
|
||||
output = vllm_model.run(
|
||||
prompt=self.prompt,
|
||||
images=images,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
|
||||
return StringOutput(value=output)
|
||||
@@ -67,7 +67,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
invert: bool = InputField(default=False, description="Whether to invert the mask.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
|
||||
if self.invert:
|
||||
mask[0] = torch.tensor(np.array(image)[:, :, 3] == 0, dtype=torch.bool)
|
||||
|
||||
@@ -181,7 +181,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
|
||||
@invocation("lora_loader", title="Apply LoRA - SD1.5", tags=["model"], category="model", version="1.0.4")
|
||||
class LoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
@@ -244,7 +244,7 @@ class LoRASelectorOutput(BaseInvocationOutput):
|
||||
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
|
||||
|
||||
|
||||
@invocation("lora_selector", title="LoRA Model - SD1.5", tags=["model"], category="model", version="1.0.2")
|
||||
@invocation("lora_selector", title="Select LoRA", tags=["model"], category="model", version="1.0.3")
|
||||
class LoRASelectorInvocation(BaseInvocation):
|
||||
"""Selects a LoRA model and weight."""
|
||||
|
||||
@@ -258,7 +258,7 @@ class LoRASelectorInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"lora_collection_loader", title="LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.1"
|
||||
"lora_collection_loader", title="Apply LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.2"
|
||||
)
|
||||
class LoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
|
||||
@@ -322,10 +322,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"sdxl_lora_loader",
|
||||
title="LoRA Model - SDXL",
|
||||
title="Apply LoRA - SDXL",
|
||||
tags=["lora", "model"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
version="1.0.5",
|
||||
)
|
||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
@@ -402,10 +402,10 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"sdxl_lora_collection_loader",
|
||||
title="LoRA Collection - SDXL",
|
||||
title="Apply LoRA Collection - SDXL",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.1.1",
|
||||
version="1.1.2",
|
||||
)
|
||||
class SDXLLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
|
||||
|
||||
@@ -49,8 +49,6 @@ def run_app() -> None:
|
||||
# Miscellaneous startup tasks.
|
||||
apply_monkeypatches()
|
||||
register_mime_types()
|
||||
if app_config.dev_reload:
|
||||
enable_dev_reload()
|
||||
check_cudnn(logger)
|
||||
|
||||
# Initialize the app and event loop.
|
||||
@@ -61,6 +59,11 @@ def run_app() -> None:
|
||||
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
|
||||
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
|
||||
|
||||
if app_config.dev_reload:
|
||||
# load_custom_nodes seems to bypass jurrigged's import sniffer, so be sure to call it *after* they're already
|
||||
# imported.
|
||||
enable_dev_reload(custom_nodes_path=app_config.custom_nodes_path)
|
||||
|
||||
# Start the server.
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
|
||||
@@ -38,9 +38,11 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigBase,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
)
|
||||
from invokeai.backend.model_manager.legacy_probe import ModelProbe
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
@@ -49,7 +51,6 @@ from invokeai.backend.model_manager.metadata import (
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
@@ -182,9 +183,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or ModelRecordChanges()
|
||||
info: AnyModelConfig = ModelProbe.probe(
|
||||
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
|
||||
) # type: ignore
|
||||
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
|
||||
|
||||
if preferred_name := config.name:
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
@@ -644,12 +643,22 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
|
||||
config = config or ModelRecordChanges()
|
||||
hash_algo = self._app_config.hashing_algorithm
|
||||
fields = config.model_dump()
|
||||
|
||||
try:
|
||||
return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
|
||||
except InvalidModelConfigException:
|
||||
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
config = config or ModelRecordChanges()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
|
||||
info = info or self._probe(model_path, config)
|
||||
|
||||
model_path = model_path.resolve()
|
||||
|
||||
|
||||
@@ -85,8 +85,11 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
|
||||
def torch_load_file(checkpoint: Path) -> AnyModel:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
if scan_result.scan_err:
|
||||
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
|
||||
|
||||
result = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
|
||||
|
||||
@@ -570,7 +570,10 @@ ValueToInsertTuple: TypeAlias = tuple[
|
||||
str | None, # destination (optional)
|
||||
int | None, # retried_from_item_id (optional, this is always None for new items)
|
||||
]
|
||||
"""A type alias for the tuple of values to insert into the session queue table."""
|
||||
"""A type alias for the tuple of values to insert into the session queue table.
|
||||
|
||||
**If you change this, be sure to update the `enqueue_batch` and `retry_items_by_id` methods in the session queue service!**
|
||||
"""
|
||||
|
||||
|
||||
def prepare_values_to_insert(
|
||||
|
||||
@@ -27,6 +27,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
ValueToInsertTuple,
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
@@ -689,7 +690,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
"""Retries the given queue items"""
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
values_to_insert: list[tuple] = []
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
for item_id in item_ids:
|
||||
@@ -715,16 +716,16 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
else queue_item.item_id
|
||||
)
|
||||
|
||||
value_to_insert = (
|
||||
value_to_insert: ValueToInsertTuple = (
|
||||
queue_item.queue_id,
|
||||
queue_item.batch_id,
|
||||
queue_item.destination,
|
||||
field_values_json,
|
||||
queue_item.origin,
|
||||
queue_item.priority,
|
||||
workflow_json,
|
||||
cloned_session_json,
|
||||
cloned_session.id,
|
||||
queue_item.batch_id,
|
||||
field_values_json,
|
||||
queue_item.priority,
|
||||
workflow_json,
|
||||
queue_item.origin,
|
||||
queue_item.destination,
|
||||
retried_from_item_id,
|
||||
)
|
||||
values_to_insert.append(value_to_insert)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
@@ -33,8 +34,9 @@ def check_cudnn(logger: logging.Logger) -> None:
|
||||
)
|
||||
|
||||
|
||||
def enable_dev_reload() -> None:
|
||||
def enable_dev_reload(custom_nodes_path=None) -> None:
|
||||
"""Enable hot reloading on python file changes during development."""
|
||||
import invokeai
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
try:
|
||||
@@ -44,7 +46,10 @@ def enable_dev_reload() -> None:
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
|
||||
) from e
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
paths = [str(Path(invokeai.__file__).with_name("*.py"))]
|
||||
if custom_nodes_path:
|
||||
paths.append(str(custom_nodes_path / "*.py"))
|
||||
jurigged.watch(pattern=paths, logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
|
||||
|
||||
def apply_monkeypatches() -> None:
|
||||
|
||||
@@ -20,6 +20,7 @@ class ModelSpec:
|
||||
|
||||
max_seq_lengths: Dict[str, Literal[256, 512]] = {
|
||||
"flux-dev": 512,
|
||||
"flux-dev-fill": 512,
|
||||
"flux-schnell": 256,
|
||||
}
|
||||
|
||||
@@ -68,4 +69,19 @@ params = {
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
"flux-dev-fill": FluxParams(
|
||||
in_channels=384,
|
||||
out_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
}
|
||||
|
||||
49
invokeai/backend/llava_onevision_model.py
Normal file
49
invokeai/backend/llava_onevision_model.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class LlavaOnevisionModel(RawModel):
|
||||
def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor: LlavaOnevisionProcessor):
|
||||
self._vllm_model = vllm_model
|
||||
self._processor = processor
|
||||
|
||||
@classmethod
|
||||
def load_from_path(cls, path: str | Path):
|
||||
vllm_model = LlavaOnevisionForConditionalGeneration.from_pretrained(path, local_files_only=True)
|
||||
assert isinstance(vllm_model, LlavaOnevisionForConditionalGeneration)
|
||||
processor = AutoProcessor.from_pretrained(path, local_files_only=True)
|
||||
assert isinstance(processor, LlavaOnevisionProcessor)
|
||||
return cls(vllm_model, processor)
|
||||
|
||||
def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str:
|
||||
# TODO(ryand): Tune the max number of images that are useful for the model.
|
||||
if len(images) > 3:
|
||||
raise ValueError(
|
||||
f"{len(images)} images were provided as input to the LLaVA OneVision model. "
|
||||
"Pass <=3 images for good performance."
|
||||
)
|
||||
|
||||
# Define a chat history and use `apply_chat_template` to get correctly formatted prompt.
|
||||
# "content" is a list of dicts with types "text" or "image".
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
# Add the correct number of images.
|
||||
for _ in images:
|
||||
content.append({"type": "image"})
|
||||
|
||||
conversation = [{"role": "user", "content": content}]
|
||||
prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
inputs = self._processor(images=images or None, text=prompt, return_tensors="pt").to(device=device, dtype=dtype)
|
||||
output = self._vllm_model.generate(**inputs, max_new_tokens=400, do_sample=False)
|
||||
output_str: str = self._processor.decode(output[0][2:], skip_special_tokens=True)
|
||||
# The output_str will include the prompt, so we extract the response.
|
||||
response = output_str.split("assistant\n", 1)[1].strip()
|
||||
return response
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
self._vllm_model.to(device=device, dtype=dtype)
|
||||
@@ -5,6 +5,7 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigBase,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
@@ -13,8 +14,8 @@ from invokeai.backend.model_manager.config import (
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.legacy_probe import ModelProbe
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
__all__ = [
|
||||
@@ -32,4 +33,5 @@ __all__ = [
|
||||
"ModelVariantType",
|
||||
"SchedulerPredictionType",
|
||||
"SubModelType",
|
||||
"ModelConfigBase",
|
||||
]
|
||||
|
||||
@@ -20,21 +20,34 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
|
||||
"""
|
||||
|
||||
# pyright: reportIncompatibleVariableOverride=false
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, TypeAlias, Union
|
||||
from inspect import isabstract
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, Optional, TypeAlias, Union
|
||||
|
||||
import diffusers
|
||||
import onnxruntime as ort
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
@@ -44,7 +57,7 @@ AnyModel = Union[
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
"""Exception for when config parser doesn't recognize this combination of model type and format."""
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
@@ -78,6 +91,7 @@ class ModelType(str, Enum):
|
||||
SpandrelImageToImage = "spandrel_image_to_image"
|
||||
SigLIP = "siglip"
|
||||
FluxRedux = "flux_redux"
|
||||
LlavaOnevision = "llava_onevision"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
@@ -190,12 +204,98 @@ class MainModelDefaultSettings(BaseModel):
|
||||
class ControlAdapterDefaultSettings(BaseModel):
|
||||
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
|
||||
preprocessor: str | None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
class ModelOnDisk:
|
||||
"""A utility class representing a model stored on disk."""
|
||||
|
||||
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
|
||||
self.path = path
|
||||
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
|
||||
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
self.name = path.stem
|
||||
else:
|
||||
self.name = path.name
|
||||
self.hash_algo = hash_algo
|
||||
|
||||
def hash(self):
|
||||
return ModelHash(algorithm=self.hash_algo).hash(self.path)
|
||||
|
||||
def size(self):
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
return self.path.stat().st_size
|
||||
return sum(file.stat().st_size for file in self.path.rglob("*"))
|
||||
|
||||
def component_paths(self):
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
return {self.path}
|
||||
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
|
||||
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
||||
|
||||
def repo_variant(self):
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
return None
|
||||
|
||||
weight_files = list(self.path.glob("**/*.safetensors"))
|
||||
weight_files.extend(list(self.path.glob("**/*.bin")))
|
||||
for x in weight_files:
|
||||
if ".fp16" in x.suffixes:
|
||||
return ModelRepoVariant.FP16
|
||||
if "openvino_model" in x.name:
|
||||
return ModelRepoVariant.OpenVINO
|
||||
if "flax_model" in x.name:
|
||||
return ModelRepoVariant.Flax
|
||||
if x.suffix == ".onnx":
|
||||
return ModelRepoVariant.ONNX
|
||||
return ModelRepoVariant.Default
|
||||
|
||||
@staticmethod
|
||||
def load_state_dict(path: Path):
|
||||
with SilenceWarnings():
|
||||
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
|
||||
checkpoint = torch.load(path, map_location="cpu")
|
||||
elif path.suffix.endswith(".gguf"):
|
||||
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
|
||||
elif path.suffix.endswith(".safetensors"):
|
||||
checkpoint = safetensors.torch.load_file(path)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized model extension: {path.suffix}")
|
||||
|
||||
state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
return state_dict
|
||||
|
||||
|
||||
class MatchSpeed(int, Enum):
|
||||
"""Represents the estimated runtime speed of a config's 'matches' method."""
|
||||
|
||||
FAST = 0
|
||||
MED = 1
|
||||
SLOW = 2
|
||||
|
||||
|
||||
class ModelConfigBase(ABC, BaseModel):
|
||||
"""
|
||||
Abstract Base class for model configurations.
|
||||
|
||||
To create a new config type, inherit from this class and implement its interface:
|
||||
- (mandatory) override methods 'matches' and 'parse'
|
||||
- (mandatory) define fields 'type' and 'format' as class attributes
|
||||
|
||||
- (optional) override method 'get_tag'
|
||||
- (optional) override field _MATCH_SPEED
|
||||
|
||||
See MinimalConfigExample in test_model_probe.py for an example implementation.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any]) -> None:
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||
|
||||
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
|
||||
hash: str = Field(description="The hash of the model file(s).")
|
||||
@@ -203,27 +303,134 @@ class ModelConfigBase(BaseModel):
|
||||
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
||||
)
|
||||
name: str = Field(description="Name of the model.")
|
||||
type: ModelType = Field(description="Model type")
|
||||
format: ModelFormat = Field(description="Model format")
|
||||
base: BaseModelType = Field(description="The base model.")
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
||||
source_type: ModelSourceType = Field(description="The type of source")
|
||||
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
source_api_response: Optional[str] = Field(
|
||||
description="The original API response from the source, as stringified JSON.", default=None
|
||||
)
|
||||
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
|
||||
description="Loadable submodels in this model", default=None
|
||||
)
|
||||
|
||||
_USING_LEGACY_PROBE: ClassVar[set] = set()
|
||||
_USING_CLASSIFY_API: ClassVar[set] = set()
|
||||
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
|
||||
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if issubclass(cls, LegacyProbeMixin):
|
||||
ModelConfigBase._USING_LEGACY_PROBE.add(cls)
|
||||
else:
|
||||
ModelConfigBase._USING_CLASSIFY_API.add(cls)
|
||||
|
||||
@staticmethod
|
||||
def all_config_classes():
|
||||
subclasses = ModelConfigBase._USING_LEGACY_PROBE | ModelConfigBase._USING_CLASSIFY_API
|
||||
concrete = {cls for cls in subclasses if not isabstract(cls)}
|
||||
return concrete
|
||||
|
||||
@staticmethod
|
||||
def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
|
||||
"""
|
||||
Returns the best matching ModelConfig instance from a model's file/folder path.
|
||||
Raises InvalidModelConfigException if no valid configuration is found.
|
||||
Created to deprecate ModelProbe.probe
|
||||
"""
|
||||
candidates = ModelConfigBase._USING_CLASSIFY_API
|
||||
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
|
||||
mod = ModelOnDisk(model_path, hash_algo)
|
||||
|
||||
for config_cls in sorted_by_match_speed:
|
||||
try:
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
except InvalidModelConfigException:
|
||||
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
|
||||
|
||||
raise InvalidModelConfigException("No valid config found")
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
type = cls.model_fields["type"].default.value
|
||||
format = cls.model_fields["format"].default.value
|
||||
return Tag(f"{type}.{format}")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
"""Returns a dictionary with the fields needed to construct the model.
|
||||
Raises InvalidModelConfigException if the model is invalid.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
"""Performs a quick check to determine if the config matches the model.
|
||||
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def cast_overrides(overrides: dict[str, Any]):
|
||||
"""Casts user overrides from str to Enum"""
|
||||
if "type" in overrides:
|
||||
overrides["type"] = ModelType(overrides["type"])
|
||||
|
||||
if "format" in overrides:
|
||||
overrides["format"] = ModelFormat(overrides["format"])
|
||||
|
||||
if "base" in overrides:
|
||||
overrides["base"] = BaseModelType(overrides["base"])
|
||||
|
||||
if "source_type" in overrides:
|
||||
overrides["source_type"] = ModelSourceType(overrides["source_type"])
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
|
||||
"""Creates an instance of this config or raises InvalidModelConfigException."""
|
||||
if not cls.matches(mod):
|
||||
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
|
||||
|
||||
fields = cls.parse(mod)
|
||||
cls.cast_overrides(overrides)
|
||||
fields.update(overrides)
|
||||
|
||||
type = fields.get("type") or cls.model_fields["type"].default
|
||||
base = fields.get("base") or cls.model_fields["base"].default
|
||||
|
||||
fields["path"] = mod.path.as_posix()
|
||||
fields["source"] = fields.get("source") or fields["path"]
|
||||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||
fields["name"] = name = fields.get("name") or mod.name
|
||||
fields["hash"] = fields.get("hash") or mod.hash()
|
||||
fields["key"] = fields.get("key") or uuid_string()
|
||||
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
|
||||
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
class LegacyProbeMixin:
|
||||
"""Mixin for classes using the legacy probe for model classification."""
|
||||
|
||||
@classmethod
|
||||
def matches(cls, *args, **kwargs):
|
||||
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, *args, **kwargs):
|
||||
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
|
||||
|
||||
|
||||
class CheckpointConfigBase(ABC, BaseModel):
|
||||
"""Base class for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
|
||||
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
|
||||
@@ -234,153 +441,109 @@ class CheckpointConfigBase(ModelConfigBase):
|
||||
)
|
||||
|
||||
|
||||
class DiffusersConfigBase(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
class DiffusersConfigBase(ABC, BaseModel):
|
||||
"""Base class for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
||||
|
||||
|
||||
class LoRAConfigBase(ModelConfigBase):
|
||||
class LoRAConfigBase(ABC, BaseModel):
|
||||
"""Base class for LoRA models."""
|
||||
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
|
||||
|
||||
class T5EncoderConfigBase(ModelConfigBase):
|
||||
class T5EncoderConfigBase(ABC, BaseModel):
|
||||
"""Base class for diffusers-style models."""
|
||||
|
||||
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
|
||||
|
||||
|
||||
class T5EncoderConfig(T5EncoderConfigBase):
|
||||
class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
|
||||
|
||||
|
||||
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
|
||||
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase):
|
||||
class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class ControlAdapterConfigBase(BaseModel):
|
||||
class ControlAdapterConfigBase(ABC, BaseModel):
|
||||
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
|
||||
class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
|
||||
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for Control LoRA models."""
|
||||
|
||||
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):
|
||||
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for Control LoRA models."""
|
||||
|
||||
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class LoRADiffusersConfig(LoRAConfigBase):
|
||||
class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class VAECheckpointConfig(CheckpointConfigBase):
|
||||
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class VAEDiffusersConfig(ModelConfigBase):
|
||||
class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
|
||||
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class TextualInversionFileConfig(ModelConfigBase):
|
||||
class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
||||
|
||||
|
||||
class TextualInversionFolderConfig(ModelConfigBase):
|
||||
class TextualInversionFolderConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||
|
||||
|
||||
class MainConfigBase(ModelConfigBase):
|
||||
class MainConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings] = Field(
|
||||
@@ -389,167 +552,146 @@ class MainConfigBase(ModelConfigBase):
|
||||
variant: AnyVariant = ModelVariantType.Normal
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.format = ModelFormat.BnbQuantizednf4b
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
|
||||
|
||||
|
||||
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.format = ModelFormat.GGUFQuantized
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
|
||||
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
||||
pass
|
||||
|
||||
|
||||
class IPAdapterBaseConfig(ModelConfigBase):
|
||||
class IPAdapterConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||
|
||||
|
||||
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
|
||||
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for IP Adapter diffusers format models."""
|
||||
|
||||
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
|
||||
# time. Need to go through the history to make sure I'm understanding this fully.
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
|
||||
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
|
||||
|
||||
|
||||
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
|
||||
class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for IP Adapter checkpoint format models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for Clip Embeddings."""
|
||||
|
||||
variant: ClipVariantType = Field(description="Clip variant for this model")
|
||||
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
variant: ClipVariantType = ClipVariantType.L
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
|
||||
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for CLIP-G Embeddings."""
|
||||
|
||||
variant: ClipVariantType = ClipVariantType.G
|
||||
variant: Literal[ClipVariantType.G] = ClipVariantType.G
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
|
||||
|
||||
|
||||
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
|
||||
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for CLIP-L Embeddings."""
|
||||
|
||||
variant: ClipVariantType = ClipVariantType.L
|
||||
variant: Literal[ClipVariantType.L] = ClipVariantType.L
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
||||
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for Spandrel Image to Image models."""
|
||||
|
||||
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
|
||||
|
||||
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class SigLIPConfig(DiffusersConfigBase):
|
||||
class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for SigLIP."""
|
||||
|
||||
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.SigLIP.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class FluxReduxConfig(ModelConfigBase):
|
||||
class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for FLUX Tools Redux model."""
|
||||
|
||||
type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.FluxRedux.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
"""Model config for Llava Onevision models."""
|
||||
|
||||
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.format_type == ModelFormat.Checkpoint:
|
||||
return False
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
try:
|
||||
with open(config_path, "r") as file:
|
||||
config = json.load(file)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
architectures = config.get("architectures")
|
||||
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {
|
||||
"base": BaseModelType.Any,
|
||||
"variant": ModelVariantType.Normal,
|
||||
}
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
@@ -557,22 +699,40 @@ def get_model_discriminator_value(v: Any) -> str:
|
||||
Computes the discriminator value for a model config.
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||
"""
|
||||
format_ = None
|
||||
type_ = None
|
||||
format_ = type_ = variant_ = None
|
||||
|
||||
if isinstance(v, dict):
|
||||
format_ = v.get("format")
|
||||
if isinstance(format_, Enum):
|
||||
format_ = format_.value
|
||||
|
||||
type_ = v.get("type")
|
||||
if isinstance(type_, Enum):
|
||||
type_ = type_.value
|
||||
|
||||
variant_ = v.get("variant")
|
||||
if isinstance(variant_, Enum):
|
||||
variant_ = variant_.value
|
||||
else:
|
||||
format_ = v.format.value
|
||||
type_ = v.type.value
|
||||
v = f"{type_}.{format_}"
|
||||
return v
|
||||
variant_ = getattr(v, "variant", None)
|
||||
if variant_:
|
||||
variant_ = variant_.value
|
||||
|
||||
# Ideally, each config would be uniquely identified with a combination of fields
|
||||
# i.e. (type, format, variant) without any special cases. Alas...
|
||||
|
||||
# Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
|
||||
# To maintain compatibility, we default to ClipVariantType.L in this case.
|
||||
if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
|
||||
variant_ = variant_ or ClipVariantType.L.value
|
||||
return f"{type_}.{format_}.{variant_}"
|
||||
return f"{type_}.{format_}"
|
||||
|
||||
|
||||
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
|
||||
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
|
||||
AnyModelConfig = Annotated[
|
||||
Union[
|
||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
@@ -596,11 +756,11 @@ AnyModelConfig = Annotated[
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
|
||||
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
|
||||
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
@@ -609,39 +769,12 @@ AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
|
||||
|
||||
|
||||
class ModelConfigFactory(object):
|
||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||
|
||||
@classmethod
|
||||
def make_config(
|
||||
cls,
|
||||
model_data: Union[Dict[str, Any], AnyModelConfig],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type[ModelConfigBase]] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Return the appropriate config object from raw dict values.
|
||||
|
||||
:param model_data: A raw dict corresponding the obect fields to be
|
||||
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
|
||||
object, which will be passed through unchanged.
|
||||
:param dest_class: The config class to be returned. If not provided, will
|
||||
be selected automatically.
|
||||
"""
|
||||
model: Optional[ModelConfigBase] = None
|
||||
if isinstance(model_data, ModelConfigBase):
|
||||
model = model_data
|
||||
elif dest_class:
|
||||
model = dest_class.model_validate(model_data)
|
||||
else:
|
||||
# mypy doesn't typecheck TypeAdapters well?
|
||||
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
||||
assert model is not None
|
||||
if key:
|
||||
model.key = key
|
||||
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
|
||||
class ModelConfigFactory:
|
||||
@staticmethod
|
||||
def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
|
||||
"""Return the appropriate config object from raw dict values."""
|
||||
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
|
||||
if isinstance(model, CheckpointConfigBase) and timestamp:
|
||||
model.converted_at = timestamp
|
||||
if model:
|
||||
validate_hash(model.hash)
|
||||
validate_hash(model.hash)
|
||||
return model # type: ignore
|
||||
|
||||
@@ -3,10 +3,10 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Union
|
||||
|
||||
import picklescan.scanner as pscan
|
||||
import safetensors.torch
|
||||
import spandrel
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
@@ -141,6 +141,7 @@ class ModelProbe(object):
|
||||
"SD3Transformer2DModel": ModelType.Main,
|
||||
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
|
||||
"SiglipModel": ModelType.SigLIP,
|
||||
"LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision,
|
||||
}
|
||||
|
||||
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
|
||||
@@ -416,20 +417,22 @@ class ModelProbe(object):
|
||||
# TODO: Decide between dev/schnell
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
|
||||
# HACK: For FLUX, config_file is used as a key into invokeai.backend.flux.util.params during model
|
||||
# loading. When FLUX support was first added, it was decided that this was the easiest way to support
|
||||
# the various FLUX formats rather than adding new model types/formats. Be careful when modifying this in
|
||||
# the future.
|
||||
if (
|
||||
"guidance_in.out_layer.weight" in state_dict
|
||||
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
|
||||
):
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-dev"
|
||||
if variant_type == ModelVariantType.Normal:
|
||||
config_file = "flux-dev"
|
||||
elif variant_type == ModelVariantType.Inpaint:
|
||||
config_file = "flux-dev-fill"
|
||||
else:
|
||||
raise ValueError(f"Unexpected FLUX variant type: {variant_type}")
|
||||
else:
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the discriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-schnell"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
@@ -482,9 +485,11 @@ class ModelProbe(object):
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
scan_result = pscan.scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
if scan_result.scan_err:
|
||||
raise Exception(f"Error scanning model {model_name} for malware. Aborting import.")
|
||||
|
||||
|
||||
# Probing utilities
|
||||
@@ -552,9 +557,27 @@ class CheckpointProbeBase(ProbeBase):
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
|
||||
base_type = self.get_base_type()
|
||||
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
|
||||
if model_type != ModelType.Main:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
|
||||
if base_type == BaseModelType.Flux:
|
||||
in_channels = state_dict["img_in.weight"].shape[1]
|
||||
|
||||
# FLUX Model variant types are distinguished by input channels:
|
||||
# - Unquantized Dev and Schnell have in_channels=64
|
||||
# - BNB-NF4 Dev and Schnell have in_channels=1
|
||||
# - FLUX Fill has in_channels=384
|
||||
# - Unsure of quantized FLUX Fill models
|
||||
# - Unsure of GGUF-quantized models
|
||||
if in_channels == 384:
|
||||
# This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant
|
||||
# type is used to determine whether to use the fill model or the base model.
|
||||
return ModelVariantType.Inpaint
|
||||
else:
|
||||
# Fall back on "normal" variant type for all other FLUX models.
|
||||
return ModelVariantType.Normal
|
||||
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
@@ -767,6 +790,11 @@ class FluxReduxCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.Flux
|
||||
|
||||
|
||||
class LlavaOnevisionCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
@@ -1047,6 +1075,11 @@ class FluxReduxFolderProbe(FolderProbeBase):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LlaveOnevisionFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.model_path / "config.json"
|
||||
@@ -1082,6 +1115,7 @@ ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderPro
|
||||
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
@@ -1095,5 +1129,6 @@ ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpoi
|
||||
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
@@ -0,0 +1,32 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LlavaOnevision, format=ModelFormat.Diffusers)
|
||||
class LlavaOnevisionModelLoader(ModelLoader):
|
||||
"""Class for loading LLaVA Onevision VLLM models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
model = LlavaOnevisionModel.load_from_path(model_path)
|
||||
model.to(dtype=self._torch_dtype)
|
||||
return model
|
||||
@@ -614,6 +614,26 @@ flux_redux = StarterModel(
|
||||
)
|
||||
# endregion
|
||||
|
||||
# region LlavaOnevisionModel
|
||||
llava_onevision = StarterModel(
|
||||
name="LLaVA Onevision Qwen2 0.5B",
|
||||
base=BaseModelType.Any,
|
||||
source="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
description="LLaVA Onevision VLLM model",
|
||||
type=ModelType.LlavaOnevision,
|
||||
)
|
||||
# endregion
|
||||
|
||||
# region FLUX Fill
|
||||
flux_fill = StarterModel(
|
||||
name="FLUX Fill",
|
||||
base=BaseModelType.Flux,
|
||||
source="black-forest-labs/FLUX.1-Fill-dev::flux1-fill-dev.safetensors",
|
||||
description="FLUX Fill model (for inpainting).",
|
||||
type=ModelType.Main,
|
||||
)
|
||||
# endregion
|
||||
|
||||
# List of starter models, displayed on the frontend.
|
||||
# The order/sort of this list is not changed by the frontend - set it how you want it here.
|
||||
STARTER_MODELS: list[StarterModel] = [
|
||||
@@ -683,6 +703,8 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
clip_l_encoder,
|
||||
siglip,
|
||||
flux_redux,
|
||||
llava_onevision,
|
||||
flux_fill,
|
||||
]
|
||||
|
||||
sd1_bundle: list[StarterModel] = [
|
||||
@@ -731,6 +753,7 @@ flux_bundle: list[StarterModel] = [
|
||||
flux_canny_control_lora,
|
||||
flux_depth_control_lora,
|
||||
flux_redux,
|
||||
flux_fill,
|
||||
]
|
||||
|
||||
STARTER_BUNDLES: dict[str, list[StarterModel]] = {
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""Utilities for parsing model files, used mostly by probe.py"""
|
||||
"""Utilities for parsing model files, used mostly by legacy_probe.py"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import picklescan.scanner as pscan
|
||||
import safetensors
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_manager.config import ClipVariantType
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
@@ -57,9 +57,12 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
|
||||
checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
|
||||
else:
|
||||
if scan:
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
|
||||
scan_result = pscan.scan_file_path(path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
|
||||
if scan_result.scan_err:
|
||||
raise Exception(f"Error scanning model at {path} for malware. Aborting import.")
|
||||
|
||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||
return checkpoint
|
||||
|
||||
|
||||
@@ -114,7 +114,9 @@
|
||||
"layout": "Layout",
|
||||
"board": "Ordner",
|
||||
"combinatorial": "Kombinatorisch",
|
||||
"saveChanges": "Änderungen speichern"
|
||||
"saveChanges": "Änderungen speichern",
|
||||
"error_withCount_one": "{{count}} Fehler",
|
||||
"error_withCount_other": "{{count}} Fehler"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@@ -764,10 +766,10 @@
|
||||
"layerCopiedToClipboard": "Ebene in die Zwischenablage kopiert",
|
||||
"sentToCanvas": "An Leinwand gesendet",
|
||||
"problemDeletingWorkflow": "Problem beim Löschen des Arbeitsablaufs",
|
||||
"uploadFailedInvalidUploadDesc_withCount_one": "Es darf maximal 1 PNG- oder JPEG-Bild sein.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_other": "Es dürfen maximal {{count}} PNG- oder JPEG-Bilder sein.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_one": "Darf maximal 1 PNG-, JPEG- oder WEBP-Bild sein.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_other": "Dürfen maximal {{count}} PNG-, JPEG- oder WEBP-Bild sein.",
|
||||
"problemRetrievingWorkflow": "Problem beim Abrufen des Arbeitsablaufs",
|
||||
"uploadFailedInvalidUploadDesc": "Müssen PNG- oder JPEG-Bilder sein.",
|
||||
"uploadFailedInvalidUploadDesc": "Müssen PNG-, JPEG- oder WEBP-Bilder sein.",
|
||||
"pasteSuccess": "Eingefügt in {{destination}}",
|
||||
"pasteFailed": "Einfügen fehlgeschlagen",
|
||||
"unableToCopy": "Kopieren nicht möglich",
|
||||
@@ -1259,7 +1261,6 @@
|
||||
"nodePack": "Knoten-Pack",
|
||||
"loadWorkflow": "Lade Workflow",
|
||||
"snapToGrid": "Am Gitternetz einrasten",
|
||||
"unknownOutput": "Unbekannte Ausgabe: {{name}}",
|
||||
"updateNode": "Knoten updaten",
|
||||
"edge": "Rand / Kante",
|
||||
"sourceNodeDoesNotExist": "Ungültiger Rand: Quell- / Ausgabe-Knoten {{node}} existiert nicht",
|
||||
@@ -1325,7 +1326,9 @@
|
||||
"description": "Beschreibung",
|
||||
"loadWorkflowDesc": "Arbeitsablauf laden?",
|
||||
"loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen.",
|
||||
"loadingTemplates": "Lade {{name}}"
|
||||
"loadingTemplates": "Lade {{name}}",
|
||||
"missingSourceOrTargetHandle": "Fehlender Quell- oder Zielgriff",
|
||||
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Korrektur für hohe Auflösungen",
|
||||
|
||||
@@ -194,7 +194,9 @@
|
||||
"combinatorial": "Combinatorial",
|
||||
"layout": "Layout",
|
||||
"row": "Row",
|
||||
"column": "Column"
|
||||
"column": "Column",
|
||||
"value": "Value",
|
||||
"label": "Label"
|
||||
},
|
||||
"hrf": {
|
||||
"hrf": "High Resolution Fix",
|
||||
@@ -846,6 +848,7 @@
|
||||
"starterModels": "Starter Models",
|
||||
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
|
||||
"controlLora": "Control LoRA",
|
||||
"llavaOnevision": "LLaVA OneVision",
|
||||
"syncModels": "Sync Models",
|
||||
"textualInversions": "Textual Inversions",
|
||||
"triggerPhrases": "Trigger Phrases",
|
||||
@@ -1013,7 +1016,10 @@
|
||||
"unknownNodeType": "Unknown node type",
|
||||
"unknownTemplate": "Unknown Template",
|
||||
"unknownInput": "Unknown input: {{name}}",
|
||||
"unknownOutput": "Unknown output: {{name}}",
|
||||
"missingField_withName": "Missing field \"{{name}}\"",
|
||||
"unexpectedField_withName": "Unexpected field \"{{name}}\"",
|
||||
"unknownField_withName": "Unknown field \"{{name}}\"",
|
||||
"unknownFieldEditWorkflowToFix_withName": "Workflow contains an unknown field \"{{name}}\".\nEdit the workflow to fix the issue.",
|
||||
"updateNode": "Update Node",
|
||||
"updateApp": "Update App",
|
||||
"loadingTemplates": "Loading {{name}}",
|
||||
@@ -1298,7 +1304,8 @@
|
||||
"problemDeletingWorkflow": "Problem Deleting Workflow",
|
||||
"unableToCopy": "Unable to Copy",
|
||||
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
|
||||
"unableToCopyDesc_theseSteps": "these steps"
|
||||
"unableToCopyDesc_theseSteps": "these steps",
|
||||
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks."
|
||||
},
|
||||
"popovers": {
|
||||
"clipSkip": {
|
||||
@@ -1700,6 +1707,7 @@
|
||||
"shared": "Shared",
|
||||
"browseWorkflows": "Browse Workflows",
|
||||
"deselectAll": "Deselect All",
|
||||
"recommended": "Recommended For You",
|
||||
"opened": "Opened",
|
||||
"openWorkflow": "Open Workflow",
|
||||
"updated": "Updated",
|
||||
@@ -1738,6 +1746,7 @@
|
||||
"openLibrary": "Open Library",
|
||||
"workflowThumbnail": "Workflow Thumbnail",
|
||||
"saveChanges": "Save Changes",
|
||||
"emptyStringPlaceholder": "<empty string>",
|
||||
"builder": {
|
||||
"deleteAllElements": "Delete All Form Elements",
|
||||
"resetAllNodeFields": "Reset All Node Fields",
|
||||
@@ -1762,6 +1771,9 @@
|
||||
"singleLine": "Single Line",
|
||||
"multiLine": "Multi Line",
|
||||
"slider": "Slider",
|
||||
"dropdown": "Dropdown",
|
||||
"addOption": "Add Option",
|
||||
"resetOptions": "Reset Options",
|
||||
"both": "Both",
|
||||
"emptyRootPlaceholderViewMode": "Click Edit to start building a form for this workflow.",
|
||||
"emptyRootPlaceholderEditMode": "Drag a form element or node field here to get started.",
|
||||
@@ -1948,7 +1960,8 @@
|
||||
"rgNegativePromptNotSupported": "Negative Prompt not supported for selected base model",
|
||||
"rgReferenceImagesNotSupported": "regional Reference Images not supported for selected base model",
|
||||
"rgAutoNegativeNotSupported": "Auto-Negative not supported for selected base model",
|
||||
"rgNoRegion": "no region drawn"
|
||||
"rgNoRegion": "no region drawn",
|
||||
"fluxFillIncompatibleWithControlLoRA": "Control LoRA is not compatible with FLUX Fill"
|
||||
},
|
||||
"errors": {
|
||||
"unableToFindImage": "Unable to find image",
|
||||
@@ -2331,7 +2344,7 @@
|
||||
"whatsNewInInvoke": "What's New in Invoke",
|
||||
"items": [
|
||||
"Workflows: New and improved Workflow Library.",
|
||||
"FLUX: Support for FLUX Redux in Workflows and Canvas."
|
||||
"FLUX: Support for FLUX Redux & FLUX Fill in Workflows and Canvas."
|
||||
],
|
||||
"readReleaseNotes": "Read Release Notes",
|
||||
"watchRecentReleaseVideos": "Watch Recent Release Videos",
|
||||
|
||||
@@ -1653,7 +1653,6 @@
|
||||
"collectionFieldType": "{{name}} (Collection)",
|
||||
"newWorkflow": "Nouveau Workflow",
|
||||
"reorderLinearView": "Réorganiser la vue linéaire",
|
||||
"unknownOutput": "Sortie inconnue : {{name}}",
|
||||
"outputFieldTypeParseError": "Impossible d'analyser le type du champ de sortie {{node}}.{{field}} ({{message}})",
|
||||
"unsupportedMismatchedUnion": "type CollectionOrScalar non concordant avec les types de base {{firstType}} et {{secondType}}",
|
||||
"unableToParseFieldType": "impossible d'analyser le type de champ",
|
||||
|
||||
@@ -113,7 +113,9 @@
|
||||
"saveChanges": "Salva modifiche",
|
||||
"error_withCount_one": "{{count}} errore",
|
||||
"error_withCount_many": "{{count}} errori",
|
||||
"error_withCount_other": "{{count}} errori"
|
||||
"error_withCount_other": "{{count}} errori",
|
||||
"value": "Valore",
|
||||
"label": "Etichetta"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -779,7 +781,8 @@
|
||||
"enableModelDescriptions": "Abilita le descrizioni dei modelli nei menu a discesa",
|
||||
"modelDescriptionsDisabled": "Descrizioni dei modelli nei menu a discesa disabilitate",
|
||||
"modelDescriptionsDisabledDesc": "Le descrizioni dei modelli nei menu a discesa sono state disabilitate. Abilitale nelle Impostazioni.",
|
||||
"showDetailedInvocationProgress": "Mostra dettagli avanzamento"
|
||||
"showDetailedInvocationProgress": "Mostra dettagli avanzamento",
|
||||
"enableHighlightFocusedRegions": "Evidenzia le regioni interessate"
|
||||
},
|
||||
"toast": {
|
||||
"uploadFailed": "Caricamento fallito",
|
||||
@@ -847,7 +850,8 @@
|
||||
"pasteSuccess": "Incollato su {{destination}}",
|
||||
"unableToCopy": "Impossibile copiare",
|
||||
"unableToCopyDesc": "Il tuo browser non supporta l'accesso agli appunti. Gli utenti di Firefox potrebbero risolvere il problema seguendo ",
|
||||
"unableToCopyDesc_theseSteps": "questi passaggi"
|
||||
"unableToCopyDesc_theseSteps": "questi passaggi",
|
||||
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill non è compatibile con Testo a Immagine o Immagine a Immagine. Per queste attività, utilizzare altri modelli FLUX."
|
||||
},
|
||||
"accessibility": {
|
||||
"invokeProgressBar": "Barra di avanzamento generazione",
|
||||
@@ -968,7 +972,6 @@
|
||||
"unableToGetWorkflowVersion": "Impossibile ottenere la versione dello schema del flusso di lavoro",
|
||||
"nodePack": "Pacchetto di nodi",
|
||||
"unableToExtractSchemaNameFromRef": "Impossibile estrarre il nome dello schema dal riferimento",
|
||||
"unknownOutput": "Output sconosciuto: {{name}}",
|
||||
"unknownNodeType": "Tipo di nodo sconosciuto",
|
||||
"targetNodeDoesNotExist": "Connessione non valida: il nodo di destinazione/input {{node}} non esiste",
|
||||
"unknownFieldType": "$t(nodes.unknownField) tipo: {{type}}",
|
||||
@@ -1038,7 +1041,11 @@
|
||||
"generatorImages_many": "{{count}} immagini",
|
||||
"generatorImages_other": "{{count}} immagini",
|
||||
"generatorImagesFromBoard": "Immagini dalla Bacheca",
|
||||
"missingSourceOrTargetNode": "Nodo sorgente o di destinazione mancante"
|
||||
"missingSourceOrTargetNode": "Nodo sorgente o di destinazione mancante",
|
||||
"unknownField_withName": "Campo \"{{name}}\" sconosciuto",
|
||||
"missingField_withName": "Campo \"{{name}}\" mancante",
|
||||
"unknownFieldEditWorkflowToFix_withName": "Il flusso di lavoro contiene un campo \"{{name}}\" sconosciuto .\nModifica il flusso di lavoro per risolvere il problema.",
|
||||
"unexpectedField_withName": "Campo \"{{name}}\" inaspettato"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Aggiungi automaticamente bacheca",
|
||||
@@ -1776,7 +1783,12 @@
|
||||
"text": "Testo",
|
||||
"numberInput": "Ingresso numerico",
|
||||
"containerRowLayout": "Contenitore (disposizione riga)",
|
||||
"containerColumnLayout": "Contenitore (disposizione colonna)"
|
||||
"containerColumnLayout": "Contenitore (disposizione colonna)",
|
||||
"minimum": "Minimo",
|
||||
"maximum": "Massimo",
|
||||
"dropdown": "Elenco a discesa",
|
||||
"addOption": "Aggiungi opzione",
|
||||
"resetOptions": "Reimposta opzioni"
|
||||
},
|
||||
"loadMore": "Carica altro",
|
||||
"searchPlaceholder": "Cerca per nome, descrizione o etichetta",
|
||||
@@ -1791,7 +1803,9 @@
|
||||
"private": "Privato",
|
||||
"deselectAll": "Deseleziona tutto",
|
||||
"noRecentWorkflows": "Nessun flusso di lavoro recente",
|
||||
"view": "Visualizza"
|
||||
"view": "Visualizza",
|
||||
"recommended": "Consigliato per te",
|
||||
"emptyStringPlaceholder": "<stringa vuota>"
|
||||
},
|
||||
"accordions": {
|
||||
"compositing": {
|
||||
@@ -2235,7 +2249,8 @@
|
||||
"rgNegativePromptNotSupported": "Prompt negativo non supportato per il modello base selezionato",
|
||||
"ipAdapterIncompatibleBaseModel": "modello base dell'immagine di riferimento incompatibile",
|
||||
"ipAdapterNoImageSelected": "nessuna immagine di riferimento selezionata",
|
||||
"rgAutoNegativeNotSupported": "Auto-Negativo non supportato per il modello base selezionato"
|
||||
"rgAutoNegativeNotSupported": "Auto-Negativo non supportato per il modello base selezionato",
|
||||
"fluxFillIncompatibleWithControlLoRA": "Il controllo LoRA non è compatibile con FLUX Fill"
|
||||
},
|
||||
"pasteTo": "Incolla su",
|
||||
"pasteToBboxDesc": "Nuovo livello (nel riquadro di delimitazione)",
|
||||
@@ -2350,8 +2365,8 @@
|
||||
"watchRecentReleaseVideos": "Guarda i video su questa versione",
|
||||
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
|
||||
"items": [
|
||||
"Gestione della memoria: nuova impostazione per gli utenti con GPU Nvidia per ridurre l'utilizzo della VRAM.",
|
||||
"Prestazioni: continui miglioramenti alle prestazioni e alla reattività complessive dell'applicazione."
|
||||
"Flussi di lavoro: nuova e migliorata libreria dei flussi di lavoro.",
|
||||
"FLUX: supporto per FLUX Redux e FLUX Fill in Flussi di lavoro e Tela."
|
||||
]
|
||||
},
|
||||
"system": {
|
||||
|
||||
@@ -425,7 +425,6 @@
|
||||
"newWorkflow": "Nieuwe werkstroom",
|
||||
"unknownErrorValidatingWorkflow": "Onbekende fout bij valideren werkstroom",
|
||||
"unsupportedAnyOfLength": "te veel union-leden ({{count}})",
|
||||
"unknownOutput": "Onbekende uitvoer: {{name}}",
|
||||
"viewMode": "Gebruik in lineaire weergave",
|
||||
"unableToExtractSchemaNameFromRef": "fout bij het extraheren van de schemanaam via de ref",
|
||||
"unsupportedMismatchedUnion": "niet-overeenkomende soort CollectionOrScalar met basissoorten {{firstType}} en {{secondType}}",
|
||||
|
||||
@@ -879,7 +879,6 @@
|
||||
"unableToExtractSchemaNameFromRef": "невозможно извлечь имя схемы из ссылки",
|
||||
"executionStateError": "Ошибка",
|
||||
"prototypeDesc": "Этот вызов является прототипом. Он может претерпевать изменения при обновлении приложения и может быть удален в любой момент.",
|
||||
"unknownOutput": "Неизвестный вывод: {{name}}",
|
||||
"executionStateCompleted": "Выполнено",
|
||||
"node": "Узел",
|
||||
"workflowAuthor": "Автор",
|
||||
|
||||
@@ -236,7 +236,8 @@
|
||||
"layout": "Bố Cục",
|
||||
"row": "Hàng",
|
||||
"board": "Bảng",
|
||||
"saveChanges": "Lưu Thay Đổi"
|
||||
"saveChanges": "Lưu Thay Đổi",
|
||||
"error_withCount_other": "{{count}} lỗi"
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Thêm Prompt Trigger",
|
||||
@@ -769,7 +770,8 @@
|
||||
"urlForbiddenErrorMessage": "Bạn có thể cần yêu cầu quyền truy cập từ trang web đang cung cấp model.",
|
||||
"urlUnauthorizedErrorMessage": "Bạn có thể cần thiếp lập một token API để dùng được model này.",
|
||||
"fluxRedux": "FLUX Redux",
|
||||
"sigLip": "SigLIP"
|
||||
"sigLip": "SigLIP",
|
||||
"llavaOnevision": "LLaVA OneVision"
|
||||
},
|
||||
"metadata": {
|
||||
"guidance": "Hướng Dẫn",
|
||||
@@ -892,7 +894,6 @@
|
||||
"targetNodeFieldDoesNotExist": "Kết nối không phù hợp: đích đến/đầu vào của vùng {{node}}.{{field}} không tồn tại",
|
||||
"missingTemplate": "Node không hợp lệ: node {{node}} thuộc loại {{type}} bị thiếu mẫu trình bày (chưa tải?)",
|
||||
"unsupportedMismatchedUnion": "Dạng số lượng dữ liệu không khớp với {{firstType}} và {{secondType}}",
|
||||
"unknownOutput": "Đầu Ra Không Rõ: {{name}}",
|
||||
"betaDesc": "Trình kích hoạt này vẫn trong giai đoạn beta. Cho đến khi ổn định, nó có thể phá hỏng thay đổi trong khi cập nhật ứng dụng. Chúng tôi dự định hỗ trợ trình kích hoạt này về lâu dài.",
|
||||
"cannotConnectInputToInput": "Không thế kết nối đầu vào với đầu vào",
|
||||
"showEdgeLabelsHelp": "Hiển thị tên trên kết nối, chỉ ra những node được kết nối",
|
||||
@@ -1019,7 +1020,11 @@
|
||||
"downloadWorkflowError": "Lỗi tải xuống workflow",
|
||||
"generatorImagesFromBoard": "Ảnh Từ Bảng",
|
||||
"generatorImagesCategory": "Phân Loại",
|
||||
"generatorImages_other": "{{count}} ảnh"
|
||||
"generatorImages_other": "{{count}} ảnh",
|
||||
"unknownField_withName": "Vùng Dữ Liệu Không Rõ \"{{name}}\"",
|
||||
"unexpectedField_withName": "Sai Vùng Dữ Liệu \"{{name}}\"",
|
||||
"unknownFieldEditWorkflowToFix_withName": "Workflow chứa vùng dữ liệu không rõ \"{{name}}\".\nHãy biên tập workflow để sửa lỗi.",
|
||||
"missingField_withName": "Thiếu Vùng Dữ Liệu \"{{name}}\""
|
||||
},
|
||||
"popovers": {
|
||||
"paramCFGRescaleMultiplier": {
|
||||
@@ -1618,7 +1623,8 @@
|
||||
"displayInProgress": "Hiển Thị Hình Ảnh Đang Xử Lý",
|
||||
"intermediatesClearedFailed": "Có Vấn Đề Khi Dọn Sạch Sản Phẩm Trung Gian",
|
||||
"enableInvisibleWatermark": "Bật Chế Độ Ẩn Watermark",
|
||||
"showDetailedInvocationProgress": "Hiện Dữ Liệu Xử Lý"
|
||||
"showDetailedInvocationProgress": "Hiện Dữ Liệu Xử Lý",
|
||||
"enableHighlightFocusedRegions": "Nhấn Mạnh Khu Vực Chỉ Định"
|
||||
},
|
||||
"sdxl": {
|
||||
"loading": "Đang Tải...",
|
||||
@@ -2048,7 +2054,8 @@
|
||||
"rgNegativePromptNotSupported": "Lệnh Tiêu Cực không được hỗ trợ cho model cơ sở được chọn",
|
||||
"rgReferenceImagesNotSupported": "Ảnh Mẫu Khu Vực không được hỗ trợ cho model cơ sở được chọn",
|
||||
"rgAutoNegativeNotSupported": "Tự Động Đảo Chiều không được hỗ trợ cho model cơ sở được chọn",
|
||||
"rgNoRegion": "không có khu vực được vẽ"
|
||||
"rgNoRegion": "không có khu vực được vẽ",
|
||||
"fluxFillIncompatibleWithControlLoRA": "LoRA Điều Khiển Được không tương tích với FLUX Fill"
|
||||
},
|
||||
"pasteTo": "Dán Vào",
|
||||
"pasteToAssets": "Tài Nguyên",
|
||||
@@ -2199,7 +2206,8 @@
|
||||
"unableToCopyDesc_theseSteps": "các bước sau",
|
||||
"unableToCopyDesc": "Trình duyệt của bạn không hỗ trợ tính năng clipboard. Người dùng Firefox có thể khắc phục theo ",
|
||||
"pasteSuccess": "Dán Vào {{destination}}",
|
||||
"pasteFailed": "Dán Thất Bại"
|
||||
"pasteFailed": "Dán Thất Bại",
|
||||
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill không tương tích với Từ Ngữ Sang Hình Ảnh và Hình Ảnh Sang Hình Ảnh. Dùng model FLUX khác cho các tính năng này."
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
@@ -2288,7 +2296,11 @@
|
||||
"container": "Hộp Chứa",
|
||||
"heading": "Đầu Dòng",
|
||||
"text": "Văn Bản",
|
||||
"divider": "Gạch Chia"
|
||||
"divider": "Gạch Chia",
|
||||
"minimum": "Tối Thiểu",
|
||||
"maximum": "Tối Đa",
|
||||
"containerRowLayout": "Hộp Chứa (bố cục hàng)",
|
||||
"containerColumnLayout": "Hộp Chứa (bố cục cột)"
|
||||
},
|
||||
"yourWorkflows": "Workflow Của Bạn",
|
||||
"browseWorkflows": "Khám Phá Workflow",
|
||||
@@ -2300,7 +2312,11 @@
|
||||
"filterByTags": "Lọc Theo Nhãn",
|
||||
"recentlyOpened": "Mở Gần Đây",
|
||||
"private": "Cá Nhân",
|
||||
"loadMore": "Tải Thêm"
|
||||
"loadMore": "Tải Thêm",
|
||||
"view": "Xem",
|
||||
"deselectAll": "Huỷ Chọn Tất Cả",
|
||||
"noRecentWorkflows": "Không Có Workflows Gần Đây",
|
||||
"recommended": "Có Thể Bạn Sẽ Cần"
|
||||
},
|
||||
"upscaling": {
|
||||
"missingUpscaleInitialImage": "Thiếu ảnh dùng để upscale",
|
||||
@@ -2327,7 +2343,8 @@
|
||||
"gettingStartedSeries": "Cần thêm hướng dẫn? Xem thử <LinkComponent>Bắt Đầu Làm Quen</LinkComponent> để biết thêm mẹo khai thác toàn bộ tiềm năng của Invoke Studio.",
|
||||
"toGetStarted": "Để bắt đầu, hãy nhập lệnh vào hộp và nhấp chuột vào <StrongComponent>Kích Hoạt</StrongComponent> để tạo ra bức ảnh đầu tiên. Chọn một mẫu trình bày cho lệnh để cải thiện kết quả. Bạn có thể chọn để lưu ảnh trực tiếp vào <StrongComponent>Thư Viện Ảnh</StrongComponent> hoặc chỉnh sửa chúng ở <StrongComponent>Canvas</StrongComponent>.",
|
||||
"noModelsInstalled": "Dường như bạn chưa tải model nào cả! Bạn có thể <DownloadStarterModelsButton>tải xuống các model khởi đầu</DownloadStarterModelsButton> hoặc <ImportModelsButton>nhập vào thêm model</ImportModelsButton>.",
|
||||
"lowVRAMMode": "Cho hiệu suất tốt nhất, hãy làm theo <LinkComponent>hướng dẫn VRAM Thấp</LinkComponent> của chúng tôi."
|
||||
"lowVRAMMode": "Cho hiệu suất tốt nhất, hãy làm theo <LinkComponent>hướng dẫn VRAM Thấp</LinkComponent> của chúng tôi.",
|
||||
"toGetStartedWorkflow": "Để bắt đầu, hãy điền vào khu vực bên trái và bấm <StrongComponent>Kích Hoạt</StrongComponent> nhằm tạo sinh ảnh. Muốn khám phá thêm workflow? Nhấp vào <StrongComponent>icon thư mục</StrongComponent> nằm cạnh tiêu đề workflow để xem một dãy các mẫu trình bày khác."
|
||||
},
|
||||
"whatsNew": {
|
||||
"whatsNewInInvoke": "Có Gì Mới Ở Invoke",
|
||||
@@ -2335,8 +2352,8 @@
|
||||
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
|
||||
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
|
||||
"items": [
|
||||
"Trình Quản Lý Bộ Nhớ: Thiết lập mới cho người dùng với GPU Nvidia để giảm lượng VRAM sử dụng.",
|
||||
"Hiệu suất: Các cải thiện tiếp theo nhằm gói gọn hiệu suất và khả năng phản hồi của ứng dụng."
|
||||
"Workflow: Thư Viện Workflow mới và đã được cải tiến.",
|
||||
"FLUX: Hỗ trợ FLUX Redux & FLUX Fill trong Workflow và Canvas."
|
||||
]
|
||||
},
|
||||
"upsell": {
|
||||
|
||||
@@ -908,7 +908,6 @@
|
||||
"unableToGetWorkflowVersion": "无法获取工作流架构版本",
|
||||
"nodePack": "节点包",
|
||||
"unableToExtractSchemaNameFromRef": "无法从参考中提取架构名",
|
||||
"unknownOutput": "未知输出:{{name}}",
|
||||
"unknownErrorValidatingWorkflow": "验证工作流时出现未知错误",
|
||||
"collectionFieldType": "{{name}}(合集)",
|
||||
"unknownNodeType": "未知节点类型",
|
||||
|
||||
@@ -12,6 +12,12 @@ import { sentImageToCanvas } from 'features/gallery/store/actions';
|
||||
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
|
||||
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
|
||||
import { $isWorkflowLibraryModalOpen } from 'features/nodes/store/workflowLibraryModal';
|
||||
import {
|
||||
$workflowLibraryTagOptions,
|
||||
workflowLibraryTagsReset,
|
||||
workflowLibraryTagToggled,
|
||||
workflowLibraryViewChanged,
|
||||
} from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { activeTabCanvasRightPanelChanged, setActiveTab } from 'features/ui/store/uiSlice';
|
||||
@@ -30,9 +36,17 @@ type SendToCanvasAction = _StudioInitAction<'sendToCanvas', { imageName: string
|
||||
type UseAllParametersAction = _StudioInitAction<'useAllParameters', { imageName: string }>;
|
||||
type StudioDestinationAction = _StudioInitAction<
|
||||
'goToDestination',
|
||||
{ destination: 'generation' | 'canvas' | 'workflows' | 'upscaling' | 'viewAllWorkflows' | 'viewAllStylePresets' }
|
||||
{
|
||||
destination:
|
||||
| 'generation'
|
||||
| 'canvas'
|
||||
| 'workflows'
|
||||
| 'upscaling'
|
||||
| 'viewAllWorkflows'
|
||||
| 'viewAllWorkflowsRecommended'
|
||||
| 'viewAllStylePresets';
|
||||
}
|
||||
>;
|
||||
|
||||
// Use global state to show loader until we are ready to render the studio.
|
||||
export const $didStudioInit = atom(false);
|
||||
|
||||
@@ -58,6 +72,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
const didParseOpenAPISchema = useStore($hasTemplates);
|
||||
const store = useAppStore();
|
||||
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
|
||||
const workflowLibraryTagOptions = useStore($workflowLibraryTagOptions);
|
||||
|
||||
const handleSendToCanvas = useCallback(
|
||||
async (imageName: string) => {
|
||||
@@ -173,6 +188,18 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
store.dispatch(setActiveTab('workflows'));
|
||||
$isWorkflowLibraryModalOpen.set(true);
|
||||
break;
|
||||
case 'viewAllWorkflowsRecommended':
|
||||
// Go to the workflows tab and open the workflow library modal with the recommended workflows view
|
||||
store.dispatch(setActiveTab('workflows'));
|
||||
$isWorkflowLibraryModalOpen.set(true);
|
||||
store.dispatch(workflowLibraryViewChanged('defaults'));
|
||||
store.dispatch(workflowLibraryTagsReset());
|
||||
for (const tag of workflowLibraryTagOptions) {
|
||||
if (tag.recommended) {
|
||||
store.dispatch(workflowLibraryTagToggled(tag.label));
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 'viewAllStylePresets':
|
||||
// Go to the canvas tab and open the style presets menu
|
||||
store.dispatch(setActiveTab('canvas'));
|
||||
@@ -180,7 +207,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
break;
|
||||
}
|
||||
},
|
||||
[store]
|
||||
[store, workflowLibraryTagOptions]
|
||||
);
|
||||
|
||||
const handleStudioInitAction = useCallback(
|
||||
|
||||
@@ -27,7 +27,8 @@ export type AppFeature =
|
||||
| 'bulkDownload'
|
||||
| 'starterModels'
|
||||
| 'hfToken'
|
||||
| 'retryQueueItem';
|
||||
| 'retryQueueItem'
|
||||
| 'cancelAndClearAll';
|
||||
/**
|
||||
* A disable-able Stable Diffusion feature
|
||||
*/
|
||||
|
||||
@@ -19,7 +19,7 @@ const styles: CSSProperties = { position: 'absolute', top: 0, left: 0, right: 0,
|
||||
|
||||
const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
|
||||
const overlayscrollbarsOptions = useMemo(
|
||||
() => getOverlayScrollbarsParams(overflowX, overflowY).options,
|
||||
() => getOverlayScrollbarsParams({ overflowX, overflowY }).options,
|
||||
[overflowX, overflowY]
|
||||
);
|
||||
const [os, osRef] = useState<OverlayScrollbarsComponentRef | null>(null);
|
||||
|
||||
@@ -20,12 +20,22 @@ export const overlayScrollbarsParams: UseOverlayScrollbarsParams = {
|
||||
},
|
||||
};
|
||||
|
||||
export const getOverlayScrollbarsParams = (
|
||||
overflowX: 'hidden' | 'scroll' = 'hidden',
|
||||
overflowY: 'hidden' | 'scroll' = 'scroll'
|
||||
) => {
|
||||
export const getOverlayScrollbarsParams = ({
|
||||
overflowX = 'hidden',
|
||||
overflowY = 'scroll',
|
||||
visibility = 'auto',
|
||||
}: {
|
||||
overflowX?: 'hidden' | 'scroll';
|
||||
overflowY?: 'hidden' | 'scroll';
|
||||
visibility?: 'auto' | 'hidden' | 'visible';
|
||||
}) => {
|
||||
const params = deepClone(overlayScrollbarsParams);
|
||||
merge(params, { options: { overflow: { y: overflowY, x: overflowX } } });
|
||||
merge(params, {
|
||||
options: {
|
||||
overflow: { y: overflowY, x: overflowX },
|
||||
scrollbars: { visibility, autoHide: visibility === 'visible' ? 'never' : 'scroll' },
|
||||
},
|
||||
});
|
||||
return params;
|
||||
};
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled';
|
||||
import { selectModel } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
@@ -19,11 +18,12 @@ import { upperFirst } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiWarningBold } from 'react-icons/pi';
|
||||
import { selectMainModelConfig } from 'services/api/endpoints/models';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => {
|
||||
return createSelector(selectCanvasSlice, selectModel, (canvas, model) => {
|
||||
return createSelector(selectCanvasSlice, selectMainModelConfig, (canvas, model) => {
|
||||
// This component is used within a <CanvasEntityStateGate /> so we can safely assume that the entity exists.
|
||||
// Should never throw.
|
||||
const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings');
|
||||
|
||||
@@ -301,8 +301,9 @@ export class CanvasBrushToolModule extends CanvasModuleBase {
|
||||
points,
|
||||
strokeWidth: settings.brushWidth,
|
||||
color: this.manager.stateApi.getCurrentColor(),
|
||||
// When shift is held, the line may extend beyond the clip region. No clip for these lines.
|
||||
clip: isShiftDraw ? null : this.parent.getClip(selectedEntity.state),
|
||||
// When shift is held, the line may extend beyond the clip region. Clip only if we are clipping to bbox. If we
|
||||
// are clipping to stage, we don't need to clip at all.
|
||||
clip: isShiftDraw && !settings.clipToBbox ? null : this.parent.getClip(selectedEntity.state),
|
||||
});
|
||||
} else {
|
||||
const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'brush_line');
|
||||
@@ -325,8 +326,9 @@ export class CanvasBrushToolModule extends CanvasModuleBase {
|
||||
points,
|
||||
strokeWidth: settings.brushWidth,
|
||||
color: this.manager.stateApi.getCurrentColor(),
|
||||
// When shift is held, the line may extend beyond the clip region. No clip for these lines.
|
||||
clip: isShiftDraw ? null : this.parent.getClip(selectedEntity.state),
|
||||
// When shift is held, the line may extend beyond the clip region. Clip only if we are clipping to bbox. If we
|
||||
// are clipping to stage, we don't need to clip at all.
|
||||
clip: isShiftDraw && !settings.clipToBbox ? null : this.parent.getClip(selectedEntity.state),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -5,7 +5,7 @@ import type {
|
||||
CanvasReferenceImageState,
|
||||
CanvasRegionalGuidanceState,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
const WARNINGS = {
|
||||
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
|
||||
@@ -20,13 +20,14 @@ const WARNINGS = {
|
||||
CONTROL_ADAPTER_NO_MODEL_SELECTED: 'controlLayers.warnings.controlAdapterNoModelSelected',
|
||||
CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL: 'controlLayers.warnings.controlAdapterIncompatibleBaseModel',
|
||||
CONTROL_ADAPTER_NO_CONTROL: 'controlLayers.warnings.controlAdapterNoControl',
|
||||
FLUX_FILL_NO_WORKY_WITH_CONTROL_LORA: 'controlLayers.warnings.fluxFillIncompatibleWithControlLoRA',
|
||||
} as const;
|
||||
|
||||
type WarningTKey = (typeof WARNINGS)[keyof typeof WARNINGS];
|
||||
|
||||
export const getRegionalGuidanceWarnings = (
|
||||
entity: CanvasRegionalGuidanceState,
|
||||
model: ParameterModel | null
|
||||
model: MainModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
@@ -78,7 +79,7 @@ export const getRegionalGuidanceWarnings = (
|
||||
|
||||
export const getGlobalReferenceImageWarnings = (
|
||||
entity: CanvasReferenceImageState,
|
||||
model: ParameterModel | null
|
||||
model: MainModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
@@ -110,7 +111,7 @@ export const getGlobalReferenceImageWarnings = (
|
||||
|
||||
export const getControlLayerWarnings = (
|
||||
entity: CanvasControlLayerState,
|
||||
model: ParameterModel | null
|
||||
model: MainModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
@@ -129,6 +130,13 @@ export const getControlLayerWarnings = (
|
||||
} else if (entity.controlAdapter.model.base !== model.base) {
|
||||
// Supported model architecture but doesn't match
|
||||
warnings.push(WARNINGS.CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL);
|
||||
} else if (
|
||||
model.base === 'flux' &&
|
||||
model.variant === 'inpaint' &&
|
||||
entity.controlAdapter.model.type === 'control_lora'
|
||||
) {
|
||||
// FLUX inpaint variants are FLUX Fill models - not compatible w/ Control LoRA
|
||||
warnings.push(WARNINGS.FLUX_FILL_NO_WORKY_WITH_CONTROL_LORA);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,7 +145,7 @@ export const getControlLayerWarnings = (
|
||||
|
||||
export const getRasterLayerWarnings = (
|
||||
_entity: CanvasRasterLayerState,
|
||||
_model: ParameterModel | null
|
||||
_model: MainModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
@@ -148,7 +156,7 @@ export const getRasterLayerWarnings = (
|
||||
|
||||
export const getInpaintMaskWarnings = (
|
||||
_entity: CanvasInpaintMaskState,
|
||||
_model: ParameterModel | null
|
||||
_model: MainModelConfig | null | undefined
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
|
||||
@@ -21,7 +21,10 @@ type Props = {
|
||||
extraCopyActions?: { label: string; getData: (data: unknown) => unknown }[];
|
||||
} & FlexProps;
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams('scroll', 'scroll').options;
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({
|
||||
overflowX: 'scroll',
|
||||
overflowY: 'scroll',
|
||||
}).options;
|
||||
|
||||
const DataViewer = (props: Props) => {
|
||||
const { label, data, fileName, withDownload = true, withCopy = true, extraCopyActions, ...rest } = props;
|
||||
|
||||
@@ -26,6 +26,7 @@ const useFocusRegionOptions = {
|
||||
};
|
||||
|
||||
const FOCUS_REGION_STYLES: SystemStyleObject = {
|
||||
display: 'flex',
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
position: 'absolute',
|
||||
@@ -45,7 +46,7 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
|
||||
<FocusRegionWrapper region="viewer" sx={FOCUS_REGION_STYLES} layerStyle="first" {...useFocusRegionOptions}>
|
||||
{hasImageToCompare && <CompareToolbar />}
|
||||
{!hasImageToCompare && <ViewerToolbar closeButton={closeButton} />}
|
||||
<Box ref={containerRef} w="full" h="full" p={2}>
|
||||
<Box ref={containerRef} w="full" h="full" p={2} overflow="hidden">
|
||||
{!hasImageToCompare && <CurrentImagePreview />}
|
||||
{hasImageToCompare && <ImageComparison containerDims={containerDims} />}
|
||||
</Box>
|
||||
|
||||
@@ -42,8 +42,7 @@ export class InvalidModelConfigError extends Error {
|
||||
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
|
||||
const { dispatch } = getStore();
|
||||
try {
|
||||
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
|
||||
req.unsubscribe();
|
||||
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
|
||||
return await req.unwrap();
|
||||
} catch {
|
||||
throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`);
|
||||
@@ -62,8 +61,9 @@ export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> =>
|
||||
const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: ModelType): Promise<AnyModelConfig> => {
|
||||
const { dispatch } = getStore();
|
||||
try {
|
||||
const req = dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }));
|
||||
req.unsubscribe();
|
||||
const req = dispatch(
|
||||
modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }, { subscribe: false })
|
||||
);
|
||||
return await req.unwrap();
|
||||
} catch {
|
||||
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
useEmbeddingModels,
|
||||
useFluxReduxModels,
|
||||
useIPAdapterModels,
|
||||
useLLaVAModels,
|
||||
useLoRAModels,
|
||||
useMainModels,
|
||||
useRefinerModels,
|
||||
@@ -126,6 +127,12 @@ const ModelList = () => {
|
||||
[fluxReduxModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [llavaOneVisionModels, { isLoading: isLoadingLlavaOneVisionModels }] = useLLaVAModels();
|
||||
const filteredLlavaOneVisionModels = useMemo(
|
||||
() => modelsFilter(llavaOneVisionModels, searchTerm, filteredModelType),
|
||||
[llavaOneVisionModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const totalFilteredModels = useMemo(() => {
|
||||
return (
|
||||
filteredMainModels.length +
|
||||
@@ -236,6 +243,17 @@ const ModelList = () => {
|
||||
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
|
||||
<ModelListWrapper title={t('modelManager.clipEmbed')} modelList={filteredClipEmbedModels} key="clip-embed" />
|
||||
)}
|
||||
|
||||
{/* LLaVA OneVision List */}
|
||||
{isLoadingLlavaOneVisionModels && <FetchingModelsLoader loadingMessage="Loading LLaVA OneVision Models..." />}
|
||||
{!isLoadingLlavaOneVisionModels && filteredLlavaOneVisionModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title={t('modelManager.llavaOnevision')}
|
||||
modelList={filteredLlavaOneVisionModels}
|
||||
key="llava-onevision"
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Spandrel Image to Image List */}
|
||||
{isLoadingSpandrelImageToImageModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading Image-to-Image Models..." />
|
||||
|
||||
@@ -25,8 +25,9 @@ export const ModelTypeFilter = memo(() => {
|
||||
clip_vision: 'CLIP Vision',
|
||||
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
|
||||
control_lora: t('modelManager.controlLora'),
|
||||
siglip: t('modelManager.siglip'),
|
||||
siglip: t('modelManager.sigLip'),
|
||||
flux_redux: t('modelManager.fluxRedux'),
|
||||
llava_onevision: t('modelManager.llavaOnevision'),
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
|
||||
@@ -14,6 +14,7 @@ import BottomLeftPanel from './flow/panels/BottomLeftPanel/BottomLeftPanel';
|
||||
import MinimapPanel from './flow/panels/MinimapPanel/MinimapPanel';
|
||||
|
||||
const FOCUS_REGION_STYLES: SystemStyleObject = {
|
||||
display: 'flex',
|
||||
position: 'relative',
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useInputFieldDescription } from 'features/nodes/hooks/useInputFieldDescription';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { ChangeEvent } from 'react';
|
||||
@@ -48,7 +48,7 @@ InputFieldDescriptionPopover.displayName = 'InputFieldDescriptionPopover';
|
||||
const Content = memo(({ nodeId, fieldName }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const description = useInputFieldDescription(nodeId, fieldName);
|
||||
const description = useInputFieldDescriptionSafe(nodeId, fieldName);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(fieldDescriptionChanged({ nodeId, fieldName, val: e.target.value }));
|
||||
|
||||
@@ -7,7 +7,7 @@ import { InputFieldResetToDefaultValueIconButton } from 'features/nodes/componen
|
||||
import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd-hooks';
|
||||
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
|
||||
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { NO_DRAG_CLASS } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useRef } from 'react';
|
||||
@@ -22,7 +22,7 @@ interface Props {
|
||||
}
|
||||
|
||||
export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
|
||||
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
|
||||
const isInvalid = useInputFieldIsInvalid(nodeId, fieldName);
|
||||
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
|
||||
|
||||
|
||||
@@ -1,23 +1,83 @@
|
||||
import { InputFieldUnknownPlaceholder } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldUnknownPlaceholder';
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { InputFieldWrapper } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper';
|
||||
import { useInputFieldInstanceExists } from 'features/nodes/hooks/useInputFieldInstanceExists';
|
||||
import { useInputFieldNameSafe } from 'features/nodes/hooks/useInputFieldNameSafe';
|
||||
import { useInputFieldTemplateExists } from 'features/nodes/hooks/useInputFieldTemplateExists';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
import type { PropsWithChildren, ReactNode } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = PropsWithChildren<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
fallback?: ReactNode;
|
||||
formatLabel?: (name: string) => string;
|
||||
}>;
|
||||
|
||||
export const InputFieldGate = memo(({ nodeId, fieldName, children }: Props) => {
|
||||
export const InputFieldGate = memo(({ nodeId, fieldName, children, fallback, formatLabel }: Props) => {
|
||||
const hasInstance = useInputFieldInstanceExists(nodeId, fieldName);
|
||||
const hasTemplate = useInputFieldTemplateExists(nodeId, fieldName);
|
||||
|
||||
if (!hasTemplate || !hasInstance) {
|
||||
return <InputFieldUnknownPlaceholder nodeId={nodeId} fieldName={fieldName} />;
|
||||
// fallback may be null, indicating we should render nothing at all - must check for undefined explicitly
|
||||
if (fallback !== undefined) {
|
||||
return fallback;
|
||||
}
|
||||
return (
|
||||
<Fallback
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
formatLabel={formatLabel}
|
||||
hasInstance={hasInstance}
|
||||
hasTemplate={hasTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return children;
|
||||
});
|
||||
|
||||
InputFieldGate.displayName = 'InputFieldGate';
|
||||
|
||||
const Fallback = memo(
|
||||
({
|
||||
nodeId,
|
||||
fieldName,
|
||||
formatLabel,
|
||||
hasTemplate,
|
||||
hasInstance,
|
||||
}: {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
formatLabel?: (name: string) => string;
|
||||
hasTemplate: boolean;
|
||||
hasInstance: boolean;
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const name = useInputFieldNameSafe(nodeId, fieldName);
|
||||
const label = useMemo(() => {
|
||||
if (formatLabel) {
|
||||
return formatLabel(name);
|
||||
}
|
||||
if (hasTemplate && !hasInstance) {
|
||||
return t('nodes.missingField_withName', { name });
|
||||
}
|
||||
if (!hasTemplate && hasInstance) {
|
||||
return t('nodes.unexpectedField_withName', { name });
|
||||
}
|
||||
return t('nodes.unknownField_withName', { name });
|
||||
}, [formatLabel, hasInstance, hasTemplate, name, t]);
|
||||
|
||||
return (
|
||||
<InputFieldWrapper>
|
||||
<Flex w="full" px={1} py={1} justifyContent="center">
|
||||
<Text fontWeight="semibold" color="error.300" whiteSpace="pre" textAlign="center">
|
||||
{label}
|
||||
</Text>
|
||||
</Flex>
|
||||
</InputFieldWrapper>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
Fallback.displayName = 'Fallback';
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
@@ -62,7 +62,7 @@ const handleStyles = {
|
||||
} satisfies CSSProperties;
|
||||
|
||||
export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
|
||||
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
|
||||
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
|
||||
const isModelField = useMemo(() => isModelFieldType(fieldTemplate.type), [fieldTemplate.type]);
|
||||
|
||||
@@ -13,10 +13,11 @@ import { StringGeneratorFieldInputComponent } from 'features/nodes/components/fl
|
||||
import { IntegerFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/IntegerField/IntegerFieldInput';
|
||||
import { IntegerFieldInputAndSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/IntegerField/IntegerFieldInputAndSlider';
|
||||
import { IntegerFieldSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/IntegerField/IntegerFieldSlider';
|
||||
import { StringFieldDropdown } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldDropdown';
|
||||
import { StringFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldInput';
|
||||
import { StringFieldTextarea } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldTextarea';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isBoardFieldInputTemplate,
|
||||
@@ -62,6 +63,8 @@ import {
|
||||
isIntegerGeneratorFieldInputTemplate,
|
||||
isIPAdapterModelFieldInputInstance,
|
||||
isIPAdapterModelFieldInputTemplate,
|
||||
isLLaVAModelFieldInputInstance,
|
||||
isLLaVAModelFieldInputTemplate,
|
||||
isLoRAModelFieldInputInstance,
|
||||
isLoRAModelFieldInputTemplate,
|
||||
isMainModelFieldInputInstance,
|
||||
@@ -112,6 +115,7 @@ import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInput
|
||||
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
|
||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
|
||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||
@@ -132,7 +136,7 @@ type Props = {
|
||||
|
||||
export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props) => {
|
||||
const field = useInputFieldInstance(nodeId, fieldName);
|
||||
const template = useInputFieldTemplate(nodeId, fieldName);
|
||||
const template = useInputFieldTemplateOrThrow(nodeId, fieldName);
|
||||
|
||||
// When deciding which component to render, first we check the type of the template, which is more efficient than the
|
||||
// instance type check. The instance type check uses zod and is slower.
|
||||
@@ -148,7 +152,7 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
if (!isStringFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
if (settings?.type !== 'string-field-config') {
|
||||
if (!settings || settings.type !== 'string-field-config') {
|
||||
if (template.ui_component === 'textarea') {
|
||||
return <StringFieldTextarea nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
} else {
|
||||
@@ -159,8 +163,10 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <StringFieldInput nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
} else if (settings.component === 'textarea') {
|
||||
return <StringFieldTextarea nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
} else if (settings.component === 'dropdown') {
|
||||
return <StringFieldDropdown nodeId={nodeId} field={field} fieldTemplate={template} settings={settings} />;
|
||||
} else {
|
||||
assert<Equals<never, typeof settings.component>>(false, 'Unexpected settings.component');
|
||||
assert<Equals<never, typeof settings>>(false, 'Unexpected settings');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,6 +328,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isLLaVAModelFieldInputTemplate(template)) {
|
||||
if (!isLLaVAModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputTemplate(template)) {
|
||||
if (!isFluxVAEModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
|
||||
import { useInputFieldLabel } from 'features/nodes/hooks/useInputFieldLabel';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateTitle } from 'features/nodes/hooks/useInputFieldTemplateTitle';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
|
||||
@@ -43,7 +43,7 @@ interface Props {
|
||||
export const InputFieldTitle = memo((props: Props) => {
|
||||
const { nodeId, fieldName, isInvalid, isDragging } = props;
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const label = useInputFieldLabel(nodeId, fieldName);
|
||||
const label = useInputFieldLabelSafe(nodeId, fieldName);
|
||||
const fieldTemplateTitle = useInputFieldTemplateTitle(nodeId, fieldName);
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
|
||||
import { useInputFieldErrors } from 'features/nodes/hooks/useInputFieldErrors';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
@@ -16,7 +16,7 @@ export const InputFieldTooltipContent = memo(({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const fieldInstance = useInputFieldInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
|
||||
const fieldErrors = useInputFieldErrors(nodeId, fieldName);
|
||||
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InputFieldWrapper } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper';
|
||||
import { useInputFieldName } from 'features/nodes/hooks/useInputFieldName';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
};
|
||||
|
||||
export const InputFieldUnknownPlaceholder = memo(({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const name = useInputFieldName(nodeId, fieldName);
|
||||
|
||||
return (
|
||||
<InputFieldWrapper>
|
||||
<FormControl isInvalid={true} alignItems="stretch" justifyContent="center" gap={2} h="full" w="full">
|
||||
<FormLabel display="flex" mb={0} px={1} py={2} gap={2}>
|
||||
{t('nodes.unknownInput', { name })}
|
||||
</FormLabel>
|
||||
</FormControl>
|
||||
</InputFieldWrapper>
|
||||
);
|
||||
});
|
||||
|
||||
InputFieldUnknownPlaceholder.displayName = 'InputFieldUnknownPlaceholder';
|
||||
@@ -1,7 +1,10 @@
|
||||
import { OutputFieldUnknownPlaceholder } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldUnknownPlaceholder';
|
||||
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { OutputFieldWrapper } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldWrapper';
|
||||
import { useOutputFieldName } from 'features/nodes/hooks/useOutputFieldName';
|
||||
import { useOutputFieldTemplateExists } from 'features/nodes/hooks/useOutputFieldTemplateExists';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = PropsWithChildren<{
|
||||
nodeId: string;
|
||||
@@ -12,10 +15,27 @@ export const OutputFieldGate = memo(({ nodeId, fieldName, children }: Props) =>
|
||||
const hasTemplate = useOutputFieldTemplateExists(nodeId, fieldName);
|
||||
|
||||
if (!hasTemplate) {
|
||||
return <OutputFieldUnknownPlaceholder nodeId={nodeId} fieldName={fieldName} />;
|
||||
return <Fallback nodeId={nodeId} fieldName={fieldName} />;
|
||||
}
|
||||
|
||||
return children;
|
||||
});
|
||||
|
||||
OutputFieldGate.displayName = 'OutputFieldGate';
|
||||
|
||||
const Fallback = memo(({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const name = useOutputFieldName(nodeId, fieldName);
|
||||
|
||||
return (
|
||||
<OutputFieldWrapper>
|
||||
<FormControl isInvalid={true} alignItems="stretch" justifyContent="space-between" gap={2} h="full" w="full">
|
||||
<FormLabel display="flex" alignItems="center" h="full" color="error.300" mb={0} px={1} gap={2}>
|
||||
{t('nodes.unexpectedField_withName', { name })}
|
||||
</FormLabel>
|
||||
</FormControl>
|
||||
</OutputFieldWrapper>
|
||||
);
|
||||
});
|
||||
|
||||
Fallback.displayName = 'Fallback';
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { OutputFieldWrapper } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldWrapper';
|
||||
import { useOutputFieldName } from 'features/nodes/hooks/useOutputFieldName';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
};
|
||||
|
||||
export const OutputFieldUnknownPlaceholder = memo(({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const name = useOutputFieldName(nodeId, fieldName);
|
||||
|
||||
return (
|
||||
<OutputFieldWrapper>
|
||||
<FormControl isInvalid={true} alignItems="stretch" justifyContent="space-between" gap={2} h="full" w="full">
|
||||
<FormLabel display="flex" alignItems="center" h="full" color="error.300" mb={0} px={1} gap={2}>
|
||||
{t('nodes.unknownOutput', { name })}
|
||||
</FormLabel>
|
||||
</FormControl>
|
||||
</OutputFieldWrapper>
|
||||
);
|
||||
});
|
||||
|
||||
OutputFieldUnknownPlaceholder.displayName = 'OutputFieldUnknownPlaceholder';
|
||||
@@ -0,0 +1,36 @@
|
||||
import { Select } from '@invoke-ai/ui-library';
|
||||
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
|
||||
import { useStringField } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/useStringField';
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { StringFieldInputInstance, StringFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import type { NodeFieldStringSettings } from 'features/nodes/types/workflow';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const StringFieldDropdown = memo(
|
||||
(
|
||||
props: FieldComponentProps<
|
||||
StringFieldInputInstance,
|
||||
StringFieldInputTemplate,
|
||||
{ settings: Extract<NodeFieldStringSettings, { component: 'dropdown' }> }
|
||||
>
|
||||
) => {
|
||||
const { value, onChange } = useStringField(props);
|
||||
|
||||
return (
|
||||
<Select
|
||||
onChange={onChange}
|
||||
className={`${NO_DRAG_CLASS} ${NO_PAN_CLASS} ${NO_WHEEL_CLASS}`}
|
||||
isDisabled={props.settings.options.length === 0}
|
||||
value={value}
|
||||
>
|
||||
{props.settings.options.map((choice, i) => (
|
||||
<option key={`${i}_${choice.value}`} value={choice.value}>
|
||||
{choice.label || choice.value || `Option ${i + 1}`}
|
||||
</option>
|
||||
))}
|
||||
</Select>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
StringFieldDropdown.displayName = 'StringFieldDropdown';
|
||||
@@ -4,12 +4,21 @@ import { useStringField } from 'features/nodes/components/flow/nodes/Invocation/
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { StringFieldInputInstance, StringFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const StringFieldInput = memo(
|
||||
(props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>) => {
|
||||
const { value, onChange } = useStringField(props);
|
||||
const { t } = useTranslation();
|
||||
|
||||
return <Input className={`${NO_DRAG_CLASS} ${NO_PAN_CLASS} ${NO_WHEEL_CLASS}`} value={value} onChange={onChange} />;
|
||||
return (
|
||||
<Input
|
||||
className={`${NO_DRAG_CLASS} ${NO_PAN_CLASS} ${NO_WHEEL_CLASS}`}
|
||||
placeholder={t('workflows.emptyStringPlaceholder')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
/>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
@@ -4,14 +4,17 @@ import { useStringField } from 'features/nodes/components/flow/nodes/Invocation/
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { StringFieldInputInstance, StringFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const StringFieldTextarea = memo(
|
||||
(props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>) => {
|
||||
const { t } = useTranslation();
|
||||
const { value, onChange } = useStringField(props);
|
||||
|
||||
return (
|
||||
<Textarea
|
||||
className={`${NO_DRAG_CLASS} ${NO_PAN_CLASS} ${NO_WHEEL_CLASS}`}
|
||||
placeholder={t('workflows.emptyStringPlaceholder')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
h="full"
|
||||
|
||||
@@ -10,7 +10,7 @@ export const useStringField = (props: FieldComponentProps<StringFieldInputInstan
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
|
||||
(e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement>) => {
|
||||
dispatch(
|
||||
fieldStringValueChanged({
|
||||
nodeId,
|
||||
|
||||
@@ -24,7 +24,7 @@ import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
|
||||
|
||||
const sx = {
|
||||
borderWidth: 1,
|
||||
|
||||
@@ -24,7 +24,7 @@ import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
|
||||
|
||||
export const FloatGeneratorFieldInputComponent = memo(
|
||||
(props: FieldComponentProps<FloatGeneratorFieldInputInstance, FloatGeneratorFieldInputTemplate>) => {
|
||||
|
||||
@@ -24,7 +24,7 @@ import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
|
||||
|
||||
const sx = {
|
||||
borderWidth: 1,
|
||||
|
||||
@@ -17,7 +17,7 @@ import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
|
||||
|
||||
export const ImageGeneratorFieldInputComponent = memo(
|
||||
(props: FieldComponentProps<ImageGeneratorFieldInputInstance, ImageGeneratorFieldInputTemplate>) => {
|
||||
|
||||
@@ -27,7 +27,7 @@ import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
|
||||
|
||||
const sx = {
|
||||
borderWidth: 1,
|
||||
|
||||
@@ -27,7 +27,7 @@ import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
|
||||
|
||||
export const IntegerGeneratorFieldInputComponent = memo(
|
||||
(props: FieldComponentProps<IntegerGeneratorFieldInputInstance, IntegerGeneratorFieldInputTemplate>) => {
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
|
||||
import type { LlavaOnevisionConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate>;
|
||||
|
||||
const LLaVAModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useLLaVAModels();
|
||||
const _onChange = useCallback(
|
||||
(value: LlavaOnevisionConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldLLaVAModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value} isDisabled={!options.length}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LLaVAModelFieldInputComponent);
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Button, Divider, Flex, FormLabel, Grid, GridItem, IconButton, Input } from '@invoke-ai/ui-library';
|
||||
import { Button, Divider, Flex, Grid, GridItem, IconButton, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
|
||||
@@ -17,7 +17,7 @@ import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
|
||||
|
||||
const sx = {
|
||||
borderWidth: 1,
|
||||
@@ -107,42 +107,6 @@ export const StringFieldCollectionInputComponent = memo(
|
||||
|
||||
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
|
||||
|
||||
type StringListItemContentProps = {
|
||||
value: string;
|
||||
index: number;
|
||||
onRemoveString: (index: number) => void;
|
||||
onChangeString: (index: number, value: string) => void;
|
||||
};
|
||||
|
||||
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveString(index);
|
||||
}, [index, onRemoveString]);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChangeString(index, e.target.value);
|
||||
},
|
||||
[index, onChangeString]
|
||||
);
|
||||
return (
|
||||
<Flex alignItems="center" gap={1}>
|
||||
<Input size="xs" resize="none" value={value} onChange={onChange} />
|
||||
<IconButton
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={onClickRemove}
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('common.remove')}
|
||||
tooltip={t('common.remove')}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
StringListItemContent.displayName = 'StringListItemContent';
|
||||
|
||||
type ListItemContentProps = {
|
||||
value: string;
|
||||
index: number;
|
||||
@@ -166,19 +130,26 @@ const ListItemContent = memo(({ value, index, onRemoveString, onChangeString }:
|
||||
return (
|
||||
<>
|
||||
<GridItem>
|
||||
<FormLabel ps={1} m={0}>
|
||||
<Text variant="subtext" textAlign="center" minW={8}>
|
||||
{index + 1}.
|
||||
</FormLabel>
|
||||
</Text>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input size="sm" resize="none" value={value} onChange={onChange} />
|
||||
<Input
|
||||
placeholder={t('workflows.emptyStringPlaceholder')}
|
||||
size="sm"
|
||||
resize="none"
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
/>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<IconButton
|
||||
tabIndex={-1}
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
minW={8}
|
||||
minH={8}
|
||||
onClick={onClickRemove}
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('common.delete')}
|
||||
|
||||
@@ -21,7 +21,7 @@ import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
|
||||
|
||||
export const StringGeneratorFieldInputComponent = memo(
|
||||
(props: FieldComponentProps<StringGeneratorFieldInputInstance, StringGeneratorFieldInputTemplate>) => {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
|
||||
import { ContainerElementSettings } from 'features/nodes/components/sidePanel/builder/ContainerElementSettings';
|
||||
import { useDepthContext } from 'features/nodes/components/sidePanel/builder/contexts';
|
||||
import { NodeFieldElementSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementSettings';
|
||||
@@ -47,8 +48,16 @@ export const FormElementEditModeHeader = memo(({ element, dragHandleRef, ...rest
|
||||
<Label element={element} />
|
||||
<Spacer />
|
||||
{isContainerElement(element) && <ContainerElementSettings element={element} />}
|
||||
{isNodeFieldElement(element) && <ZoomToNodeButton element={element} />}
|
||||
{isNodeFieldElement(element) && <NodeFieldElementSettings element={element} />}
|
||||
{isNodeFieldElement(element) && (
|
||||
<InputFieldGate
|
||||
nodeId={element.data.fieldIdentifier.nodeId}
|
||||
fieldName={element.data.fieldIdentifier.fieldName}
|
||||
fallback={null} // Do not render these buttons if the field is not found
|
||||
>
|
||||
<ZoomToNodeButton element={element} />
|
||||
<NodeFieldElementSettings element={element} />
|
||||
</InputFieldGate>
|
||||
)}
|
||||
<RemoveElementButton element={element} />
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
|
||||
import { NodeFieldElementEditMode } from 'features/nodes/components/sidePanel/builder/NodeFieldElementEditMode';
|
||||
import { NodeFieldElementViewMode } from 'features/nodes/components/sidePanel/builder/NodeFieldElementViewMode';
|
||||
import { selectWorkflowMode, useElement } from 'features/nodes/store/workflowSlice';
|
||||
@@ -15,19 +14,11 @@ export const NodeFieldElement = memo(({ id }: { id: string }) => {
|
||||
}
|
||||
|
||||
if (mode === 'view') {
|
||||
return (
|
||||
<InputFieldGate nodeId={el.data.fieldIdentifier.nodeId} fieldName={el.data.fieldIdentifier.fieldName}>
|
||||
<NodeFieldElementViewMode el={el} />
|
||||
</InputFieldGate>
|
||||
);
|
||||
return <NodeFieldElementViewMode el={el} />;
|
||||
}
|
||||
|
||||
// mode === 'edit'
|
||||
return (
|
||||
<InputFieldGate nodeId={el.data.fieldIdentifier.nodeId} fieldName={el.data.fieldIdentifier.fieldName}>
|
||||
<NodeFieldElementEditMode el={el} />
|
||||
</InputFieldGate>
|
||||
);
|
||||
return <NodeFieldElementEditMode el={el} />;
|
||||
});
|
||||
|
||||
NodeFieldElement.displayName = 'NodeFieldElement';
|
||||
|
||||
@@ -2,8 +2,8 @@ import { FormHelperText, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useInputFieldDescription } from 'features/nodes/hooks/useInputFieldDescription';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
@@ -13,8 +13,8 @@ export const NodeFieldElementDescriptionEditable = memo(({ el }: { el: NodeField
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const description = useInputFieldDescription(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplate(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const onChange = useCallback(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { Box, Divider, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
|
||||
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
|
||||
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
|
||||
import { useFormElementDnd } from 'features/nodes/components/sidePanel/builder/dnd-hooks';
|
||||
@@ -8,9 +9,11 @@ import { FormElementEditModeContent } from 'features/nodes/components/sidePanel/
|
||||
import { FormElementEditModeHeader } from 'features/nodes/components/sidePanel/builder/FormElementEditModeHeader';
|
||||
import { NodeFieldElementDescriptionEditable } from 'features/nodes/components/sidePanel/builder/NodeFieldElementDescriptionEditable';
|
||||
import { NodeFieldElementLabelEditable } from 'features/nodes/components/sidePanel/builder/NodeFieldElementLabelEditable';
|
||||
import { NodeFieldElementStringDropdownSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementStringDropdownSettings';
|
||||
import { useMouseOverFormField, useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { NODE_FIELD_CLASS_NAME } from 'features/nodes/types/workflow';
|
||||
import type { RefObject } from 'react';
|
||||
import { memo, useRef } from 'react';
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
@@ -31,25 +34,11 @@ export const NodeFieldElementEditMode = memo(({ el }: { el: NodeFieldElement })
|
||||
const dragHandleRef = useRef<HTMLDivElement>(null);
|
||||
const [activeDropRegion, isDragging] = useFormElementDnd(el.id, draggableRef, dragHandleRef);
|
||||
const containerCtx = useContainerContext();
|
||||
const { id, data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
const { id } = el;
|
||||
|
||||
return (
|
||||
<Flex ref={draggableRef} id={id} className={NODE_FIELD_CLASS_NAME} sx={sx} data-parent-layout={containerCtx.layout}>
|
||||
<FormElementEditModeHeader dragHandleRef={dragHandleRef} element={el} data-is-dragging={isDragging} />
|
||||
<FormElementEditModeContent data-is-dragging={isDragging} p={4}>
|
||||
<FormControl flex="1 1 0" orientation="vertical">
|
||||
<NodeFieldElementLabelEditable el={el} />
|
||||
<Flex w="full" gap={4}>
|
||||
<InputFieldRenderer
|
||||
nodeId={fieldIdentifier.nodeId}
|
||||
fieldName={fieldIdentifier.fieldName}
|
||||
settings={data.settings}
|
||||
/>
|
||||
</Flex>
|
||||
{showDescription && <NodeFieldElementDescriptionEditable el={el} />}
|
||||
</FormControl>
|
||||
</FormElementEditModeContent>
|
||||
<NodeFieldElementEditModeContent dragHandleRef={dragHandleRef} el={el} isDragging={isDragging} />
|
||||
<NodeFieldElementOverlay element={el} />
|
||||
<DndListDropIndicator activeDropRegion={activeDropRegion} gap="var(--invoke-space-4)" />
|
||||
</Flex>
|
||||
@@ -58,6 +47,48 @@ export const NodeFieldElementEditMode = memo(({ el }: { el: NodeFieldElement })
|
||||
|
||||
NodeFieldElementEditMode.displayName = 'NodeFieldElementEditMode';
|
||||
|
||||
const NodeFieldElementEditModeContent = memo(
|
||||
({
|
||||
el,
|
||||
dragHandleRef,
|
||||
isDragging,
|
||||
}: {
|
||||
el: NodeFieldElement;
|
||||
dragHandleRef: RefObject<HTMLDivElement>;
|
||||
isDragging: boolean;
|
||||
}) => {
|
||||
const { id, data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
return (
|
||||
<>
|
||||
<FormElementEditModeHeader dragHandleRef={dragHandleRef} element={el} data-is-dragging={isDragging} />
|
||||
<FormElementEditModeContent data-is-dragging={isDragging} p={4}>
|
||||
<InputFieldGate nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName}>
|
||||
<FormControl flex="1 1 0" orientation="vertical">
|
||||
<NodeFieldElementLabelEditable el={el} />
|
||||
<Flex w="full" gap={4}>
|
||||
<InputFieldRenderer
|
||||
nodeId={fieldIdentifier.nodeId}
|
||||
fieldName={fieldIdentifier.fieldName}
|
||||
settings={data.settings}
|
||||
/>
|
||||
</Flex>
|
||||
{showDescription && <NodeFieldElementDescriptionEditable el={el} />}
|
||||
{data.settings?.type === 'string-field-config' && data.settings.component === 'dropdown' && (
|
||||
<>
|
||||
<Divider />
|
||||
<NodeFieldElementStringDropdownSettings id={id} settings={data.settings} />
|
||||
</>
|
||||
)}
|
||||
</FormControl>
|
||||
</InputFieldGate>
|
||||
</FormElementEditModeContent>
|
||||
</>
|
||||
);
|
||||
}
|
||||
);
|
||||
NodeFieldElementEditModeContent.displayName = 'NodeFieldElementEditModeContent';
|
||||
|
||||
const nodeFieldOverlaySx: SystemStyleObject = {
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
|
||||
@@ -14,42 +14,48 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
id: string;
|
||||
config: NodeFieldFloatSettings;
|
||||
settings: NodeFieldFloatSettings;
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
fieldTemplate: FloatFieldInputTemplate;
|
||||
};
|
||||
|
||||
export const NodeFieldElementFloatSettings = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
export const NodeFieldElementFloatSettings = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
return (
|
||||
<>
|
||||
<SettingComponent id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
<SettingMin id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
<SettingMax id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
<SettingComponent
|
||||
id={id}
|
||||
settings={settings}
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
<SettingMin id={id} settings={settings} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
<SettingMax id={id} settings={settings} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
</>
|
||||
);
|
||||
});
|
||||
NodeFieldElementFloatSettings.displayName = 'NodeFieldElementFloatSettings';
|
||||
|
||||
const SettingComponent = memo(({ id, config }: Props) => {
|
||||
const SettingComponent = memo(({ id, settings }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onChangeComponent = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
const newConfig: NodeFieldFloatSettings = {
|
||||
...config,
|
||||
const newSettings: NodeFieldFloatSettings = {
|
||||
...settings,
|
||||
component: zNumberComponent.parse(e.target.value),
|
||||
};
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
},
|
||||
[config, dispatch, id]
|
||||
[settings, dispatch, id]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel flex={1}>{t('workflows.builder.component')}</FormLabel>
|
||||
<Select value={config.component} onChange={onChangeComponent} size="sm">
|
||||
<Select value={settings.component} onChange={onChangeComponent} size="sm">
|
||||
<option value="number-input">{t('workflows.builder.numberInput')}</option>
|
||||
<option value="slider">{t('workflows.builder.slider')}</option>
|
||||
<option value="number-input-and-slider">{t('workflows.builder.both')}</option>
|
||||
@@ -59,36 +65,36 @@ const SettingComponent = memo(({ id, config }: Props) => {
|
||||
});
|
||||
SettingComponent.displayName = 'SettingComponent';
|
||||
|
||||
const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
const SettingMin = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const field = useInputFieldInstance<FloatFieldInputInstance>(nodeId, fieldName);
|
||||
|
||||
const floatField = useFloatField(nodeId, fieldName, fieldTemplate);
|
||||
|
||||
const onToggleSetting = useCallback(() => {
|
||||
const newConfig: NodeFieldFloatSettings = {
|
||||
...config,
|
||||
min: config.min !== undefined ? undefined : floatField.min,
|
||||
const onToggleOverride = useCallback(() => {
|
||||
const newSettings: NodeFieldFloatSettings = {
|
||||
...settings,
|
||||
min: settings.min !== undefined ? undefined : floatField.min,
|
||||
};
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
}, [config, dispatch, floatField, id]);
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
}, [settings, dispatch, floatField, id]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(min: number) => {
|
||||
const newConfig: NodeFieldFloatSettings = {
|
||||
...config,
|
||||
const newSettings: NodeFieldFloatSettings = {
|
||||
...settings,
|
||||
min,
|
||||
};
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
|
||||
// We may need to update the value if it is outside the new min/max range
|
||||
const constrained = constrainNumber(field.value, floatField, newConfig);
|
||||
const constrained = constrainNumber(field.value, floatField, newSettings);
|
||||
if (field.value !== constrained) {
|
||||
dispatch(fieldFloatValueChanged({ nodeId, fieldName, value: constrained }));
|
||||
}
|
||||
},
|
||||
[config, dispatch, field.value, fieldName, floatField, id, nodeId]
|
||||
[settings, dispatch, field.value, fieldName, floatField, id, nodeId]
|
||||
);
|
||||
|
||||
const constraintMin = useMemo(
|
||||
@@ -97,20 +103,20 @@ const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
|
||||
);
|
||||
|
||||
const constraintMax = useMemo(
|
||||
() => (config.max ?? floatField.max) - floatField.step,
|
||||
[config.max, floatField.max, floatField.step]
|
||||
() => (settings.max ?? floatField.max) - floatField.step,
|
||||
[settings.max, floatField.max, floatField.step]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl orientation="vertical">
|
||||
<Flex justifyContent="space-between" w="full" alignItems="center">
|
||||
<FormLabel m={0}>{t('workflows.builder.minimum')}</FormLabel>
|
||||
<Switch isChecked={config.min !== undefined} onChange={onToggleSetting} size="sm" />
|
||||
<Switch isChecked={settings.min !== undefined} onChange={onToggleOverride} size="sm" />
|
||||
</Flex>
|
||||
<CompositeNumberInput
|
||||
w="full"
|
||||
isDisabled={config.min === undefined}
|
||||
value={config.min === undefined ? (`${floatField.min} (inherited)` as unknown as number) : config.min}
|
||||
isDisabled={settings.min === undefined}
|
||||
value={settings.min === undefined ? (`${floatField.min} (inherited)` as unknown as number) : settings.min}
|
||||
onChange={onChange}
|
||||
min={constraintMin}
|
||||
max={constraintMax}
|
||||
@@ -121,41 +127,41 @@ const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
|
||||
});
|
||||
SettingMin.displayName = 'SettingMin';
|
||||
|
||||
const SettingMax = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
const SettingMax = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const field = useInputFieldInstance<FloatFieldInputInstance>(nodeId, fieldName);
|
||||
|
||||
const floatField = useFloatField(nodeId, fieldName, fieldTemplate);
|
||||
|
||||
const onToggleSetting = useCallback(() => {
|
||||
const newConfig: NodeFieldFloatSettings = {
|
||||
...config,
|
||||
max: config.max !== undefined ? undefined : floatField.max,
|
||||
const onToggleOverride = useCallback(() => {
|
||||
const newSettings: NodeFieldFloatSettings = {
|
||||
...settings,
|
||||
max: settings.max !== undefined ? undefined : floatField.max,
|
||||
};
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
}, [config, dispatch, floatField, id]);
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
}, [settings, dispatch, floatField, id]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(max: number) => {
|
||||
const newConfig: NodeFieldFloatSettings = {
|
||||
...config,
|
||||
const newSettings: NodeFieldFloatSettings = {
|
||||
...settings,
|
||||
max,
|
||||
};
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
|
||||
// We may need to update the value if it is outside the new min/max range
|
||||
const constrained = constrainNumber(field.value, floatField, newConfig);
|
||||
const constrained = constrainNumber(field.value, floatField, newSettings);
|
||||
if (field.value !== constrained) {
|
||||
dispatch(fieldFloatValueChanged({ nodeId, fieldName, value: constrained }));
|
||||
}
|
||||
},
|
||||
[config, dispatch, field.value, fieldName, floatField, id, nodeId]
|
||||
[settings, dispatch, field.value, fieldName, floatField, id, nodeId]
|
||||
);
|
||||
|
||||
const constraintMin = useMemo(
|
||||
() => (config.min ?? floatField.min) + floatField.step,
|
||||
[config.min, floatField.min, floatField.step]
|
||||
() => (settings.min ?? floatField.min) + floatField.step,
|
||||
[settings.min, floatField.min, floatField.step]
|
||||
);
|
||||
|
||||
const constraintMax = useMemo(
|
||||
@@ -167,12 +173,12 @@ const SettingMax = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
|
||||
<FormControl orientation="vertical">
|
||||
<Flex justifyContent="space-between" w="full" alignItems="center">
|
||||
<FormLabel m={0}>{t('workflows.builder.maximum')}</FormLabel>
|
||||
<Switch isChecked={config.max !== undefined} onChange={onToggleSetting} size="sm" />
|
||||
<Switch isChecked={settings.max !== undefined} onChange={onToggleOverride} size="sm" />
|
||||
</Flex>
|
||||
<CompositeNumberInput
|
||||
w="full"
|
||||
isDisabled={config.max === undefined}
|
||||
value={config.max === undefined ? (`${floatField.max} (inherited)` as unknown as number) : config.max}
|
||||
isDisabled={settings.max === undefined}
|
||||
value={settings.max === undefined ? (`${floatField.max} (inherited)` as unknown as number) : settings.max}
|
||||
onChange={onChange}
|
||||
min={constraintMin}
|
||||
max={constraintMax}
|
||||
|
||||
@@ -15,42 +15,48 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
id: string;
|
||||
config: NodeFieldIntegerSettings;
|
||||
settings: NodeFieldIntegerSettings;
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
fieldTemplate: IntegerFieldInputTemplate;
|
||||
};
|
||||
|
||||
export const NodeFieldElementIntegerSettings = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
export const NodeFieldElementIntegerSettings = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
return (
|
||||
<>
|
||||
<SettingComponent id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
<SettingMin id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
<SettingMax id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
<SettingComponent
|
||||
id={id}
|
||||
settings={settings}
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
<SettingMin id={id} settings={settings} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
<SettingMax id={id} settings={settings} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
|
||||
</>
|
||||
);
|
||||
});
|
||||
NodeFieldElementIntegerSettings.displayName = 'NodeFieldElementIntegerSettings';
|
||||
|
||||
const SettingComponent = memo(({ id, config }: Props) => {
|
||||
const SettingComponent = memo(({ id, settings }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onChangeComponent = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
const newConfig: NodeFieldIntegerSettings = {
|
||||
...config,
|
||||
const newSettings: NodeFieldIntegerSettings = {
|
||||
...settings,
|
||||
component: zNumberComponent.parse(e.target.value),
|
||||
};
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
},
|
||||
[config, dispatch, id]
|
||||
[settings, dispatch, id]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel flex={1}>{t('workflows.builder.component')}</FormLabel>
|
||||
<Select value={config.component} onChange={onChangeComponent} size="sm">
|
||||
<Select value={settings.component} onChange={onChangeComponent} size="sm">
|
||||
<option value="number-input">{t('workflows.builder.numberInput')}</option>
|
||||
<option value="slider">{t('workflows.builder.slider')}</option>
|
||||
<option value="number-input-and-slider">{t('workflows.builder.both')}</option>
|
||||
@@ -60,37 +66,37 @@ const SettingComponent = memo(({ id, config }: Props) => {
|
||||
});
|
||||
SettingComponent.displayName = 'SettingComponent';
|
||||
|
||||
const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
const SettingMin = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const field = useInputFieldInstance<IntegerFieldInputInstance>(nodeId, fieldName);
|
||||
|
||||
const integerField = useIntegerField(nodeId, fieldName, fieldTemplate);
|
||||
|
||||
const onToggleSetting = useCallback(() => {
|
||||
const newConfig: NodeFieldIntegerSettings = {
|
||||
...config,
|
||||
min: config.min !== undefined ? undefined : integerField.min,
|
||||
const onToggleOverride = useCallback(() => {
|
||||
const newSettings: NodeFieldIntegerSettings = {
|
||||
...settings,
|
||||
min: settings.min !== undefined ? undefined : integerField.min,
|
||||
};
|
||||
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
}, [config, dispatch, integerField.min, id]);
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
}, [settings, dispatch, integerField.min, id]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(min: number) => {
|
||||
const newConfig: NodeFieldIntegerSettings = {
|
||||
...config,
|
||||
const newSettings: NodeFieldIntegerSettings = {
|
||||
...settings,
|
||||
min,
|
||||
};
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
|
||||
// We may need to update the value if it is outside the new min/max range
|
||||
const constrained = constrainNumber(field.value, integerField, newConfig);
|
||||
const constrained = constrainNumber(field.value, integerField, newSettings);
|
||||
if (field.value !== constrained) {
|
||||
dispatch(fieldIntegerValueChanged({ nodeId, fieldName, value: constrained }));
|
||||
}
|
||||
},
|
||||
[config, dispatch, id, field, integerField, nodeId, fieldName]
|
||||
[settings, dispatch, id, field, integerField, nodeId, fieldName]
|
||||
);
|
||||
|
||||
const constraintMin = useMemo(
|
||||
@@ -99,20 +105,20 @@ const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
|
||||
);
|
||||
|
||||
const constraintMax = useMemo(
|
||||
() => (config.max ?? integerField.max) - integerField.step,
|
||||
[config.max, integerField.max, integerField.step]
|
||||
() => (settings.max ?? integerField.max) - integerField.step,
|
||||
[settings.max, integerField.max, integerField.step]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl orientation="vertical">
|
||||
<Flex justifyContent="space-between" w="full" alignItems="center">
|
||||
<FormLabel m={0}>{t('workflows.builder.minimum')}</FormLabel>
|
||||
<Switch isChecked={config.min !== undefined} onChange={onToggleSetting} size="sm" />
|
||||
<Switch isChecked={settings.min !== undefined} onChange={onToggleOverride} size="sm" />
|
||||
</Flex>
|
||||
<CompositeNumberInput
|
||||
w="full"
|
||||
isDisabled={config.min === undefined}
|
||||
value={config.min ?? (`${integerField.min} (inherited)` as unknown as number)}
|
||||
isDisabled={settings.min === undefined}
|
||||
value={settings.min ?? (`${integerField.min} (inherited)` as unknown as number)}
|
||||
onChange={onChange}
|
||||
min={constraintMin}
|
||||
max={constraintMax}
|
||||
@@ -123,42 +129,42 @@ const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
|
||||
});
|
||||
SettingMin.displayName = 'SettingMin';
|
||||
|
||||
const SettingMax = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
const SettingMax = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const field = useInputFieldInstance<IntegerFieldInputInstance>(nodeId, fieldName);
|
||||
|
||||
const integerField = useIntegerField(nodeId, fieldName, fieldTemplate);
|
||||
|
||||
const onToggleSetting = useCallback(() => {
|
||||
const newConfig: NodeFieldIntegerSettings = {
|
||||
...config,
|
||||
max: config.max !== undefined ? undefined : integerField.max,
|
||||
const onToggleOverride = useCallback(() => {
|
||||
const newSettings: NodeFieldIntegerSettings = {
|
||||
...settings,
|
||||
max: settings.max !== undefined ? undefined : integerField.max,
|
||||
};
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
}, [config, dispatch, integerField.max, id]);
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
}, [settings, dispatch, integerField.max, id]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(max: number) => {
|
||||
const newConfig: NodeFieldIntegerSettings = {
|
||||
...config,
|
||||
const newSettings: NodeFieldIntegerSettings = {
|
||||
...settings,
|
||||
max,
|
||||
};
|
||||
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
|
||||
// We may need to update the value if it is outside the new min/max range
|
||||
const constrained = constrainNumber(field.value, integerField, newConfig);
|
||||
const constrained = constrainNumber(field.value, integerField, newSettings);
|
||||
if (field.value !== constrained) {
|
||||
dispatch(fieldIntegerValueChanged({ nodeId, fieldName, value: constrained }));
|
||||
}
|
||||
},
|
||||
[config, dispatch, field.value, fieldName, integerField, id, nodeId]
|
||||
[settings, dispatch, field.value, fieldName, integerField, id, nodeId]
|
||||
);
|
||||
|
||||
const constraintMin = useMemo(
|
||||
() => (config.min ?? integerField.min) + integerField.step,
|
||||
[config.min, integerField.min, integerField.step]
|
||||
() => (settings.min ?? integerField.min) + integerField.step,
|
||||
[settings.min, integerField.min, integerField.step]
|
||||
);
|
||||
|
||||
const constraintMax = useMemo(
|
||||
@@ -170,12 +176,12 @@ const SettingMax = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
|
||||
<FormControl orientation="vertical">
|
||||
<Flex justifyContent="space-between" w="full" alignItems="center">
|
||||
<FormLabel m={0}>{t('workflows.builder.maximum')}</FormLabel>
|
||||
<Switch isChecked={config.max !== undefined} onChange={onToggleSetting} size="sm" />
|
||||
<Switch isChecked={settings.max !== undefined} onChange={onToggleOverride} size="sm" />
|
||||
</Flex>
|
||||
<CompositeNumberInput
|
||||
w="full"
|
||||
isDisabled={config.max === undefined}
|
||||
value={config.max ?? (`${integerField.max} (inherited)` as unknown as number)}
|
||||
isDisabled={settings.max === undefined}
|
||||
value={settings.max ?? (`${integerField.max} (inherited)` as unknown as number)}
|
||||
onChange={onChange}
|
||||
min={constraintMin}
|
||||
max={constraintMax}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import { Flex, FormLabel, Spacer } from '@invoke-ai/ui-library';
|
||||
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
|
||||
import { useInputFieldLabel } from 'features/nodes/hooks/useInputFieldLabel';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
export const NodeFieldElementLabel = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const label = useInputFieldLabel(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplate(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
|
||||
const _label = useMemo(() => label || fieldTemplate.title, [label, fieldTemplate.title]);
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ import { Flex, FormLabel, Input, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
|
||||
import { useInputFieldLabel } from 'features/nodes/hooks/useInputFieldLabel';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
@@ -12,8 +12,8 @@ export const NodeFieldElementLabelEditable = memo(({ el }: { el: NodeFieldElemen
|
||||
const { data } = el;
|
||||
const { fieldIdentifier } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useInputFieldLabel(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplate(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const onChange = useCallback(
|
||||
|
||||
@@ -15,7 +15,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { NodeFieldElementFloatSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementFloatSettings';
|
||||
import { NodeFieldElementIntegerSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementIntegerSettings';
|
||||
import { NodeFieldElementStringSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementStringSettings';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
|
||||
import {
|
||||
isFloatFieldInputTemplate,
|
||||
@@ -36,7 +36,7 @@ export const NodeFieldElementSettings = memo(({ element }: { element: NodeFieldE
|
||||
const { id, data } = element;
|
||||
const { showDescription, fieldIdentifier } = data;
|
||||
const { nodeId, fieldName } = fieldIdentifier;
|
||||
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -79,7 +79,7 @@ export const NodeFieldElementSettings = memo(({ element }: { element: NodeFieldE
|
||||
{settings?.type === 'integer-field-config' && isIntegerFieldInputTemplate(fieldTemplate) && (
|
||||
<NodeFieldElementIntegerSettings
|
||||
id={id}
|
||||
config={settings}
|
||||
settings={settings}
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
fieldTemplate={fieldTemplate}
|
||||
@@ -88,13 +88,15 @@ export const NodeFieldElementSettings = memo(({ element }: { element: NodeFieldE
|
||||
{settings?.type === 'float-field-config' && isFloatFieldInputTemplate(fieldTemplate) && (
|
||||
<NodeFieldElementFloatSettings
|
||||
id={id}
|
||||
config={settings}
|
||||
settings={settings}
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
)}
|
||||
{settings?.type === 'string-field-config' && <NodeFieldElementStringSettings id={id} config={settings} />}
|
||||
{settings?.type === 'string-field-config' && (
|
||||
<NodeFieldElementStringSettings id={id} settings={settings} />
|
||||
)}
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
import { Button, ButtonGroup, Divider, Flex, Grid, GridItem, IconButton, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import { getDefaultStringOption, type NodeFieldStringDropdownSettings } from 'features/nodes/types/workflow';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold, PiPlusBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
|
||||
|
||||
export const NodeFieldElementStringDropdownSettings = memo(
|
||||
({ id, settings }: { id: string; settings: NodeFieldStringDropdownSettings }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onRemoveOption = useCallback(
|
||||
(index: number) => {
|
||||
const options = [...settings.options];
|
||||
options.splice(index, 1);
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
|
||||
},
|
||||
[settings, dispatch, id]
|
||||
);
|
||||
|
||||
const onChangeOptionValue = useCallback(
|
||||
(index: number, value: string) => {
|
||||
if (!settings.options[index]) {
|
||||
return;
|
||||
}
|
||||
const option = { ...settings.options[index], value };
|
||||
const options = [...settings.options];
|
||||
options[index] = option;
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
|
||||
},
|
||||
[dispatch, id, settings]
|
||||
);
|
||||
|
||||
const onChangeOptionLabel = useCallback(
|
||||
(index: number, label: string) => {
|
||||
if (!settings.options[index]) {
|
||||
return;
|
||||
}
|
||||
const option = { ...settings.options[index], label };
|
||||
const options = [...settings.options];
|
||||
options[index] = option;
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
|
||||
},
|
||||
[dispatch, id, settings]
|
||||
);
|
||||
|
||||
const onAddOption = useCallback(() => {
|
||||
const options = [...settings.options, getDefaultStringOption()];
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
|
||||
}, [dispatch, id, settings]);
|
||||
|
||||
const onResetOptions = useCallback(() => {
|
||||
const options = [getDefaultStringOption()];
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
|
||||
}, [dispatch, id, settings]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
className={NO_DRAG_CLASS}
|
||||
position="relative"
|
||||
w="full"
|
||||
h="auto"
|
||||
maxH={64}
|
||||
alignItems="stretch"
|
||||
justifyContent="center"
|
||||
borderRadius="base"
|
||||
flexDir="column"
|
||||
gap={2}
|
||||
>
|
||||
<ButtonGroup isAttached={false} w="full">
|
||||
<Button onClick={onAddOption} variant="ghost" flex={1} leftIcon={<PiPlusBold />}>
|
||||
{t('workflows.builder.addOption')}
|
||||
</Button>
|
||||
<Button onClick={onResetOptions} variant="ghost" flex={1} leftIcon={<PiArrowCounterClockwiseBold />}>
|
||||
{t('workflows.builder.resetOptions')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
{settings.options.length > 0 && (
|
||||
<>
|
||||
<Divider />
|
||||
<OverlayScrollbarsComponent
|
||||
className={NO_WHEEL_CLASS}
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Grid gap={1} gridTemplateColumns="auto 1fr 1fr auto" gridTemplateRows="auto 1fr" alignItems="center">
|
||||
<GridItem minW={8}>
|
||||
<Text textAlign="center" variant="subtext">
|
||||
#
|
||||
</Text>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Text textAlign="center" variant="subtext">
|
||||
{t('common.label')}
|
||||
</Text>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Text textAlign="center" variant="subtext">
|
||||
{t('common.value')}
|
||||
</Text>
|
||||
</GridItem>
|
||||
<GridItem />
|
||||
{settings.options.map(({ value, label }, index) => (
|
||||
<ListItemContent
|
||||
key={index}
|
||||
value={value}
|
||||
label={label}
|
||||
index={index}
|
||||
onRemoveOption={onRemoveOption}
|
||||
onChangeOptionValue={onChangeOptionValue}
|
||||
onChangeOptionLabel={onChangeOptionLabel}
|
||||
isRemoveDisabled={settings.options.length <= 1}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
</OverlayScrollbarsComponent>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
NodeFieldElementStringDropdownSettings.displayName = 'NodeFieldElementStringDropdownSettings';
|
||||
|
||||
type ListItemContentProps = {
|
||||
value: string;
|
||||
label: string;
|
||||
index: number;
|
||||
onRemoveOption: (index: number) => void;
|
||||
onChangeOptionValue: (index: number, value: string) => void;
|
||||
onChangeOptionLabel: (index: number, label: string) => void;
|
||||
isRemoveDisabled: boolean;
|
||||
};
|
||||
|
||||
const ListItemContent = memo(
|
||||
({
|
||||
value,
|
||||
label,
|
||||
index,
|
||||
onRemoveOption,
|
||||
onChangeOptionValue,
|
||||
onChangeOptionLabel,
|
||||
isRemoveDisabled,
|
||||
}: ListItemContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveOption(index);
|
||||
}, [index, onRemoveOption]);
|
||||
|
||||
const onChangeValue = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChangeOptionValue(index, e.target.value);
|
||||
},
|
||||
[index, onChangeOptionValue]
|
||||
);
|
||||
|
||||
const onChangeLabel = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChangeOptionLabel(index, e.target.value);
|
||||
},
|
||||
[index, onChangeOptionLabel]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<GridItem>
|
||||
<Text variant="subtext" textAlign="center">
|
||||
{index + 1}.
|
||||
</Text>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input size="sm" resize="none" placeholder="label" value={label} onChange={onChangeLabel} />
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Input size="sm" resize="none" placeholder="value" value={value} onChange={onChangeValue} />
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<IconButton
|
||||
tabIndex={-1}
|
||||
size="sm"
|
||||
variant="link"
|
||||
minW={8}
|
||||
minH={8}
|
||||
onClick={onClickRemove}
|
||||
isDisabled={isRemoveDisabled}
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('common.delete')}
|
||||
/>
|
||||
</GridItem>
|
||||
</>
|
||||
);
|
||||
}
|
||||
);
|
||||
ListItemContent.displayName = 'ListItemContent';
|
||||
@@ -1,36 +1,55 @@
|
||||
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
|
||||
import { type NodeFieldStringSettings, zStringComponent } from 'features/nodes/types/workflow';
|
||||
import { getDefaultStringOption, type NodeFieldStringSettings, zStringComponent } from 'features/nodes/types/workflow';
|
||||
import { omit } from 'lodash-es';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const NodeFieldElementStringSettings = memo(
|
||||
({ id, config }: { id: string; config: NodeFieldStringSettings }) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
type Props = {
|
||||
id: string;
|
||||
settings: NodeFieldStringSettings;
|
||||
};
|
||||
|
||||
const onChangeComponent = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
const newConfig: NodeFieldStringSettings = {
|
||||
...config,
|
||||
component: zStringComponent.parse(e.target.value),
|
||||
export const NodeFieldElementStringSettings = memo(({ id, settings }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onChangeComponent = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
const component = zStringComponent.parse(e.target.value);
|
||||
if (component === settings.component) {
|
||||
// no change
|
||||
return;
|
||||
}
|
||||
if (component === 'dropdown') {
|
||||
// if the component is changing to dropdown, we need to set the options
|
||||
const newSettings: NodeFieldStringSettings = {
|
||||
...settings,
|
||||
component,
|
||||
options: [getDefaultStringOption()],
|
||||
};
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
|
||||
},
|
||||
[config, dispatch, id]
|
||||
);
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
return;
|
||||
} else {
|
||||
// if the component is changing from dropdown, we need to remove the options
|
||||
const newSettings: NodeFieldStringSettings = omit({ ...settings, component }, 'options');
|
||||
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
|
||||
}
|
||||
},
|
||||
[settings, dispatch, id]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel flex={1}>{t('workflows.builder.component')}</FormLabel>
|
||||
<Select value={config.component} onChange={onChangeComponent} size="sm">
|
||||
<option value="input">{t('workflows.builder.singleLine')}</option>
|
||||
<option value="textarea">{t('workflows.builder.multiLine')}</option>
|
||||
</Select>
|
||||
</FormControl>
|
||||
);
|
||||
}
|
||||
);
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel flex={1}>{t('workflows.builder.component')}</FormLabel>
|
||||
<Select value={settings.component} onChange={onChangeComponent} size="sm">
|
||||
<option value="input">{t('workflows.builder.singleLine')}</option>
|
||||
<option value="textarea">{t('workflows.builder.multiLine')}</option>
|
||||
<option value="dropdown">{t('workflows.builder.dropdown')}</option>
|
||||
</Select>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
NodeFieldElementStringSettings.displayName = 'NodeFieldElementStringSettings';
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, FormControl, FormHelperText } from '@invoke-ai/ui-library';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
|
||||
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
|
||||
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
|
||||
import { NodeFieldElementLabel } from 'features/nodes/components/sidePanel/builder/NodeFieldElementLabel';
|
||||
import { useInputFieldDescription } from 'features/nodes/hooks/useInputFieldDescription';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
|
||||
import { useInputFieldTemplateOrThrow, useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { NODE_FIELD_CLASS_NAME } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
'&[data-parent-layout="column"]': {
|
||||
@@ -25,16 +27,23 @@ const sx: SystemStyleObject = {
|
||||
},
|
||||
};
|
||||
|
||||
const useFormatFallbackLabel = () => {
|
||||
const { t } = useTranslation();
|
||||
const formatLabel = useCallback((name: string) => t('nodes.unknownFieldEditWorkflowToFix_withName', { name }), [t]);
|
||||
return formatLabel;
|
||||
};
|
||||
|
||||
export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { id, data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
const description = useInputFieldDescription(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplate(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const containerCtx = useContainerContext();
|
||||
const formatFallbackLabel = useFormatFallbackLabel();
|
||||
|
||||
const _description = useMemo(
|
||||
() => description || fieldTemplate.description,
|
||||
[description, fieldTemplate.description]
|
||||
() => description || fieldTemplate?.description || '',
|
||||
[description, fieldTemplate?.description]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -45,22 +54,45 @@ export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement })
|
||||
data-parent-layout={containerCtx.layout}
|
||||
data-with-description={showDescription && !!_description}
|
||||
>
|
||||
<FormControl flex="1 1 0" orientation="vertical">
|
||||
<NodeFieldElementLabel el={el} />
|
||||
<Flex w="full" gap={4}>
|
||||
<InputFieldRenderer
|
||||
nodeId={fieldIdentifier.nodeId}
|
||||
fieldName={fieldIdentifier.fieldName}
|
||||
settings={data.settings}
|
||||
/>
|
||||
</Flex>
|
||||
{showDescription && _description && (
|
||||
<FormHelperText sx={linkifySx}>
|
||||
<Linkify options={linkifyOptions}>{_description}</Linkify>
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
<InputFieldGate
|
||||
nodeId={el.data.fieldIdentifier.nodeId}
|
||||
fieldName={el.data.fieldIdentifier.fieldName}
|
||||
formatLabel={formatFallbackLabel}
|
||||
>
|
||||
<NodeFieldElementViewModeContent el={el} />
|
||||
</InputFieldGate>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
NodeFieldElementViewMode.displayName = 'NodeFieldElementViewMode';
|
||||
|
||||
const NodeFieldElementViewModeContent = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
const { data } = el;
|
||||
const { fieldIdentifier, showDescription } = data;
|
||||
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
|
||||
|
||||
const _description = useMemo(
|
||||
() => description || fieldTemplate.description,
|
||||
[description, fieldTemplate.description]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl flex="1 1 0" orientation="vertical">
|
||||
<NodeFieldElementLabel el={el} />
|
||||
<Flex w="full" gap={4}>
|
||||
<InputFieldRenderer
|
||||
nodeId={fieldIdentifier.nodeId}
|
||||
fieldName={fieldIdentifier.fieldName}
|
||||
settings={data.settings}
|
||||
/>
|
||||
</Flex>
|
||||
{showDescription && _description && (
|
||||
<FormHelperText sx={linkifySx}>
|
||||
<Linkify options={linkifyOptions}>{_description}</Linkify>
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
NodeFieldElementViewModeContent.displayName = 'NodeFieldElementViewModeContent';
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { formElementAdded, selectFormRootElementId } from 'features/nodes/store/workflowSlice';
|
||||
import { buildNodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { useCallback } from 'react';
|
||||
@@ -8,7 +8,7 @@ import { useCallback } from 'react';
|
||||
export const useAddNodeFieldToRoot = (nodeId: string, fieldName: string) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const rootElementId = useAppSelector(selectFormRootElementId);
|
||||
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
|
||||
const field = useInputFieldInstance(nodeId, fieldName);
|
||||
|
||||
const addNodeFieldToRoot = useCallback(() => {
|
||||
|
||||
@@ -1,7 +1,20 @@
|
||||
import type { ButtonProps, CheckboxProps } from '@invoke-ai/ui-library';
|
||||
import { Button, Checkbox, Collapse, Flex, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Checkbox,
|
||||
Collapse,
|
||||
Flex,
|
||||
Icon,
|
||||
IconButton,
|
||||
Spacer,
|
||||
Text,
|
||||
Tooltip,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import type { WorkflowLibraryView, WorkflowTagCategory } from 'features/nodes/store/workflowLibrarySlice';
|
||||
import {
|
||||
$workflowLibraryCategoriesOptions,
|
||||
@@ -15,9 +28,10 @@ import {
|
||||
} from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
|
||||
import { UploadWorkflowButton } from 'features/workflowLibrary/components/UploadWorkflowButton';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold, PiUsersBold } from 'react-icons/pi';
|
||||
import { PiArrowCounterClockwiseBold, PiStarFill, PiUsersBold } from 'react-icons/pi';
|
||||
import { useDispatch } from 'react-redux';
|
||||
import { useGetCountsByTagQuery } from 'services/api/endpoints/workflows';
|
||||
|
||||
@@ -27,7 +41,7 @@ export const WorkflowLibrarySideNav = () => {
|
||||
const view = useAppSelector(selectWorkflowLibraryView);
|
||||
|
||||
return (
|
||||
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={1}>
|
||||
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
|
||||
<Flex flexDir="column" w="full" pb={2}>
|
||||
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
|
||||
</Flex>
|
||||
@@ -48,7 +62,7 @@ export const WorkflowLibrarySideNav = () => {
|
||||
)}
|
||||
</Flex>
|
||||
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
|
||||
<WorkflowLibraryViewButton view="defaults">{t('workflows.browseWorkflows')}</WorkflowLibraryViewButton>
|
||||
<BrowseWorkflowsButton />
|
||||
<DefaultsViewCheckboxesCollapsible />
|
||||
</Flex>
|
||||
<Spacer />
|
||||
@@ -58,39 +72,56 @@ export const WorkflowLibrarySideNav = () => {
|
||||
);
|
||||
};
|
||||
|
||||
const DefaultsViewCheckboxesCollapsible = memo(() => {
|
||||
const BrowseWorkflowsButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useDispatch();
|
||||
const tags = useAppSelector(selectWorkflowLibrarySelectedTags);
|
||||
const tagCategoryOptions = useStore($workflowLibraryTagCategoriesOptions);
|
||||
const view = useAppSelector(selectWorkflowLibraryView);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
|
||||
const resetTags = useCallback(() => {
|
||||
dispatch(workflowLibraryTagsReset());
|
||||
}, [dispatch]);
|
||||
|
||||
if (view === 'defaults' && selectedTags.length > 0) {
|
||||
return (
|
||||
<ButtonGroup>
|
||||
<WorkflowLibraryViewButton view="defaults" w="auto">
|
||||
{t('workflows.browseWorkflows')}
|
||||
</WorkflowLibraryViewButton>
|
||||
<Tooltip label={t('workflows.deselectAll')}>
|
||||
<IconButton
|
||||
onClick={resetTags}
|
||||
size="md"
|
||||
aria-label={t('workflows.deselectAll')}
|
||||
icon={<PiArrowCounterClockwiseBold size={12} />}
|
||||
variant="ghost"
|
||||
bg="base.700"
|
||||
color="base.50"
|
||||
/>
|
||||
</Tooltip>
|
||||
</ButtonGroup>
|
||||
);
|
||||
}
|
||||
|
||||
return <WorkflowLibraryViewButton view="defaults">{t('workflows.browseWorkflows')}</WorkflowLibraryViewButton>;
|
||||
});
|
||||
BrowseWorkflowsButton.displayName = 'BrowseWorkflowsButton';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams({ visibility: 'visible' }).options;
|
||||
|
||||
const DefaultsViewCheckboxesCollapsible = memo(() => {
|
||||
const tagCategoryOptions = useStore($workflowLibraryTagCategoriesOptions);
|
||||
const view = useAppSelector(selectWorkflowLibraryView);
|
||||
|
||||
return (
|
||||
<Collapse in={view === 'defaults'}>
|
||||
<Flex flexDir="column" gap={2} pl={4} py={2} overflow="hidden" h="100%" minH={0}>
|
||||
<Button
|
||||
isDisabled={tags.length === 0}
|
||||
onClick={resetTags}
|
||||
size="sm"
|
||||
variant="link"
|
||||
fontWeight="bold"
|
||||
justifyContent="flex-start"
|
||||
flexGrow={0}
|
||||
flexShrink={0}
|
||||
leftIcon={<PiArrowCounterClockwiseBold />}
|
||||
h={8}
|
||||
>
|
||||
{t('workflows.deselectAll')}
|
||||
</Button>
|
||||
<Flex flexDir="column" gap={2} overflow="auto">
|
||||
{tagCategoryOptions.map((tagCategory) => (
|
||||
<TagCategory key={tagCategory.categoryTKey} tagCategory={tagCategory} />
|
||||
))}
|
||||
</Flex>
|
||||
<OverlayScrollbarsComponent style={overlayScrollbarsStyles} options={overlayscrollbarsOptions}>
|
||||
<Flex flexDir="column" gap={2} overflow="auto">
|
||||
{tagCategoryOptions.map((tagCategory) => (
|
||||
<TagCategory key={tagCategory.categoryTKey} tagCategory={tagCategory} />
|
||||
))}
|
||||
</Flex>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Flex>
|
||||
</Collapse>
|
||||
);
|
||||
@@ -102,7 +133,7 @@ const useCountForIndividualTag = (tag: string) => {
|
||||
const queryArg = useMemo(
|
||||
() =>
|
||||
({
|
||||
tags: allTags,
|
||||
tags: allTags.map((tag) => tag.label),
|
||||
categories: ['default'],
|
||||
}) satisfies Parameters<typeof useGetCountsByTagQuery>[0],
|
||||
[allTags]
|
||||
@@ -127,7 +158,7 @@ const useCountForTagCategory = (tagCategory: WorkflowTagCategory) => {
|
||||
const queryArg = useMemo(
|
||||
() =>
|
||||
({
|
||||
tags: allTags,
|
||||
tags: allTags.map((tag) => tag.label),
|
||||
categories: ['default'], // We only allow filtering by tag for default workflows
|
||||
}) satisfies Parameters<typeof useGetCountsByTagQuery>[0],
|
||||
[allTags]
|
||||
@@ -140,7 +171,7 @@ const useCountForTagCategory = (tagCategory: WorkflowTagCategory) => {
|
||||
return { count: 0 };
|
||||
}
|
||||
return {
|
||||
count: tagCategory.tags.reduce((acc, tag) => acc + (data[tag] ?? 0), 0),
|
||||
count: tagCategory.tags.reduce((acc, tag) => acc + (data[tag.label] ?? 0), 0),
|
||||
};
|
||||
},
|
||||
}) satisfies Parameters<typeof useGetCountsByTagQuery>[1],
|
||||
@@ -190,7 +221,7 @@ const TagCategory = memo(({ tagCategory }: { tagCategory: WorkflowTagCategory })
|
||||
</Text>
|
||||
<Flex flexDir="column" gap={2} pl={4}>
|
||||
{tagCategory.tags.map((tag) => (
|
||||
<TagCheckbox key={tag} tag={tag} />
|
||||
<TagCheckbox key={tag.label} tag={tag} />
|
||||
))}
|
||||
</Flex>
|
||||
</Flex>
|
||||
@@ -198,14 +229,15 @@ const TagCategory = memo(({ tagCategory }: { tagCategory: WorkflowTagCategory })
|
||||
});
|
||||
TagCategory.displayName = 'TagCategory';
|
||||
|
||||
const TagCheckbox = memo(({ tag, ...rest }: CheckboxProps & { tag: string }) => {
|
||||
const TagCheckbox = memo(({ tag, ...rest }: CheckboxProps & { tag: { label: string; recommended?: boolean } }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
|
||||
const isChecked = selectedTags.includes(tag);
|
||||
const count = useCountForIndividualTag(tag);
|
||||
const isChecked = selectedTags.includes(tag.label);
|
||||
const count = useCountForIndividualTag(tag.label);
|
||||
|
||||
const onChange = useCallback(() => {
|
||||
dispatch(workflowLibraryTagToggled(tag));
|
||||
dispatch(workflowLibraryTagToggled(tag.label));
|
||||
}, [dispatch, tag]);
|
||||
|
||||
if (count === 0) {
|
||||
@@ -213,9 +245,17 @@ const TagCheckbox = memo(({ tag, ...rest }: CheckboxProps & { tag: string }) =>
|
||||
}
|
||||
|
||||
return (
|
||||
<Checkbox isChecked={isChecked} onChange={onChange} {...rest} flexShrink={0}>
|
||||
<Text>{`${tag} (${count})`}</Text>
|
||||
</Checkbox>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<Checkbox isChecked={isChecked} onChange={onChange} {...rest} flexShrink={0} />
|
||||
<Text>{`${tag.label} (${count})`}</Text>
|
||||
{tag.recommended && (
|
||||
<Tooltip label={t('workflows.recommended')}>
|
||||
<Box as="span" lineHeight={0}>
|
||||
<Icon as={PiStarFill} boxSize={4} fill="invokeYellow.500" />
|
||||
</Box>
|
||||
</Tooltip>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
TagCheckbox.displayName = 'TagCheckbox';
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
@@ -10,7 +10,7 @@ import { useCallback, useMemo } from 'react';
|
||||
export const useInputFieldDefaultValue = (nodeId: string, fieldName: string) => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
|
||||
const selectIsChanged = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useInputFieldDescription = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return '';
|
||||
}
|
||||
return node?.data.inputs[fieldName]?.description ?? '';
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const notes = useAppSelector(selector);
|
||||
return notes;
|
||||
};
|
||||
@@ -0,0 +1,26 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
/**
|
||||
* Gets the user-defined description of an input field for a given node.
|
||||
*
|
||||
* If the node doesn't exist or is not an invocation node, an empty string is returned.
|
||||
*
|
||||
* @param nodeId The ID of the node
|
||||
* @param fieldName The name of the field
|
||||
*/
|
||||
export const useInputFieldDescriptionSafe = (nodeId: string, fieldName: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
selectNodesSlice,
|
||||
(nodes) => selectFieldInputInstanceSafe(nodes, nodeId, fieldName)?.description ?? ''
|
||||
),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const description = useAppSelector(selector);
|
||||
return description;
|
||||
};
|
||||
@@ -1,18 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useInputFieldLabel = (nodeId: string, fieldName: string): string => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => {
|
||||
return selectFieldInputInstanceSafe(nodes, nodeId, fieldName)?.label ?? '';
|
||||
}),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const label = useAppSelector(selector);
|
||||
|
||||
return label;
|
||||
};
|
||||
@@ -0,0 +1,24 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
/**
|
||||
* Gets the user-defined label of an input field for a given node.
|
||||
*
|
||||
* If the node doesn't exist or is not an invocation node, an empty string is returned.
|
||||
*
|
||||
* @param nodeId The ID of the node
|
||||
* @param fieldName The name of the field
|
||||
*/
|
||||
export const useInputFieldLabelSafe = (nodeId: string, fieldName: string): string => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodes) => selectFieldInputInstanceSafe(nodes, nodeId, fieldName)?.label ?? ''),
|
||||
[fieldName, nodeId]
|
||||
);
|
||||
|
||||
const label = useAppSelector(selector);
|
||||
|
||||
return label;
|
||||
};
|
||||
@@ -5,7 +5,7 @@ import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectInvocationNodeSafe, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useInputFieldName = (nodeId: string, fieldName: string) => {
|
||||
export const useInputFieldNameSafe = (nodeId: string, fieldName: string) => {
|
||||
const templates = useStore($templates);
|
||||
|
||||
const selector = useMemo(
|
||||
@@ -3,7 +3,16 @@ import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useInputFieldTemplate = (nodeId: string, fieldName: string): FieldInputTemplate => {
|
||||
/**
|
||||
* Returns the template for a specific input field of a node.
|
||||
*
|
||||
* **Note:** This hook will throw an error if the template for the input field is not found.
|
||||
*
|
||||
* @param nodeId - The ID of the node.
|
||||
* @param fieldName - The name of the input field.
|
||||
* @throws Will throw an error if the template for the input field is not found.
|
||||
*/
|
||||
export const useInputFieldTemplateOrThrow = (nodeId: string, fieldName: string): FieldInputTemplate => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldTemplate = useMemo(() => {
|
||||
const _fieldTemplate = template.inputs[fieldName];
|
||||
@@ -12,3 +21,17 @@ export const useInputFieldTemplate = (nodeId: string, fieldName: string): FieldI
|
||||
}, [fieldName, template.inputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns the template for a specific input field of a node.
|
||||
*
|
||||
* **Note:** This function is a safe version of `useInputFieldTemplate` and will not throw an error if the template is not found.
|
||||
*
|
||||
* @param nodeId - The ID of the node.
|
||||
* @param fieldName - The name of the input field.
|
||||
*/
|
||||
export const useInputFieldTemplateSafe = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
|
||||
const template = useNodeTemplate(nodeId);
|
||||
const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? null, [fieldName, template.inputs]);
|
||||
return fieldTemplate;
|
||||
};
|
||||
|
||||
@@ -27,6 +27,7 @@ import type {
|
||||
IntegerFieldValue,
|
||||
IntegerGeneratorFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
LLaVAModelFieldValue,
|
||||
LoRAModelFieldValue,
|
||||
MainModelFieldValue,
|
||||
ModelIdentifierFieldValue,
|
||||
@@ -64,6 +65,7 @@ import {
|
||||
zIntegerFieldValue,
|
||||
zIntegerGeneratorFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zLLaVAModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zMainModelFieldValue,
|
||||
zModelIdentifierFieldValue,
|
||||
@@ -380,6 +382,9 @@ export const nodesSlice = createSlice({
|
||||
fieldLoRAModelValueChanged: (state, action: FieldValueAction<LoRAModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zLoRAModelFieldValue);
|
||||
},
|
||||
fieldLLaVAModelValueChanged: (state, action: FieldValueAction<LLaVAModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zLLaVAModelFieldValue);
|
||||
},
|
||||
fieldControlNetModelValueChanged: (state, action: FieldValueAction<ControlNetModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zControlNetModelFieldValue);
|
||||
},
|
||||
@@ -509,6 +514,7 @@ export const {
|
||||
fieldSpandrelImageToImageModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldLLaVAModelValueChanged,
|
||||
fieldModelIdentifierValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldIntegerValueChanged,
|
||||
@@ -633,6 +639,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldT2IAdapterModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldLLaVAModelValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldIntegerValueChanged,
|
||||
fieldIntegerCollectionValueChanged,
|
||||
|
||||
@@ -92,12 +92,21 @@ export const selectWorkflowLibraryView = createWorkflowLibrarySelector(({ view }
|
||||
export const DEFAULT_WORKFLOW_LIBRARY_CATEGORIES = ['user', 'default'] satisfies WorkflowCategory[];
|
||||
export const $workflowLibraryCategoriesOptions = atom<WorkflowCategory[]>(DEFAULT_WORKFLOW_LIBRARY_CATEGORIES);
|
||||
|
||||
export type WorkflowTagCategory = { categoryTKey: string; tags: string[] };
|
||||
export type WorkflowTagCategory = { categoryTKey: string; tags: Array<{ label: string; recommended?: boolean }> };
|
||||
export const DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES: WorkflowTagCategory[] = [
|
||||
{ categoryTKey: 'Industry', tags: ['Architecture', 'Fashion', 'Game Dev', 'Food'] },
|
||||
{ categoryTKey: 'Common Tasks', tags: ['Upscaling', 'Text to Image', 'Image to Image'] },
|
||||
{ categoryTKey: 'Model Architecture', tags: ['SD1.5', 'SDXL', 'SD3.5', 'FLUX'] },
|
||||
{ categoryTKey: 'Tech Showcase', tags: ['Control', 'Reference Image'] },
|
||||
{
|
||||
categoryTKey: 'Industry',
|
||||
tags: [{ label: 'Architecture' }, { label: 'Fashion' }, { label: 'Game Dev' }, { label: 'Food' }],
|
||||
},
|
||||
{
|
||||
categoryTKey: 'Common Tasks',
|
||||
tags: [{ label: 'Upscaling' }, { label: 'Text to Image' }, { label: 'Image to Image' }],
|
||||
},
|
||||
{
|
||||
categoryTKey: 'Model Architecture',
|
||||
tags: [{ label: 'SD1.5' }, { label: 'SDXL' }, { label: 'SD3.5' }, { label: 'FLUX' }],
|
||||
},
|
||||
{ categoryTKey: 'Tech Showcase', tags: [{ label: 'Control' }, { label: 'Reference Image' }] },
|
||||
];
|
||||
export const $workflowLibraryTagCategoriesOptions = atom<WorkflowTagCategory[]>(
|
||||
DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES
|
||||
|
||||
@@ -69,6 +69,7 @@ const zModelType = z.enum([
|
||||
'main',
|
||||
'vae',
|
||||
'lora',
|
||||
'llava_onevision',
|
||||
'control_lora',
|
||||
'controlnet',
|
||||
't2i_adapter',
|
||||
|
||||
@@ -189,6 +189,10 @@ const zLoRAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LoRAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zLLaVAModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('LLaVAModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zControlNetModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ControlNetModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -273,6 +277,7 @@ const zStatefulFieldType = z.union([
|
||||
zSDXLRefinerModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zLLaVAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
@@ -309,6 +314,7 @@ const modelFieldTypeNames = [
|
||||
zSDXLRefinerModelFieldType.shape.name.value,
|
||||
zVAEModelFieldType.shape.name.value,
|
||||
zLoRAModelFieldType.shape.name.value,
|
||||
zLLaVAModelFieldType.shape.name.value,
|
||||
zControlNetModelFieldType.shape.name.value,
|
||||
zIPAdapterModelFieldType.shape.name.value,
|
||||
zT2IAdapterModelFieldType.shape.name.value,
|
||||
@@ -891,6 +897,26 @@ export const isLoRAModelFieldInputInstance = buildInstanceTypeGuard(zLoRAModelFi
|
||||
export const isLoRAModelFieldInputTemplate = buildTemplateTypeGuard<LoRAModelFieldInputTemplate>('LoRAModelField');
|
||||
// #endregion
|
||||
|
||||
// #region LLaVAModelField
|
||||
export const zLLaVAModelFieldValue = zModelIdentifierField.optional();
|
||||
const zLLaVAModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zLLaVAModelFieldValue,
|
||||
});
|
||||
const zLLaVAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zLLaVAModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zLLaVAModelFieldValue,
|
||||
});
|
||||
const zLLaVAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zLLaVAModelFieldType,
|
||||
});
|
||||
export type LLaVAModelFieldValue = z.infer<typeof zLLaVAModelFieldValue>;
|
||||
export type LLaVAModelFieldInputInstance = z.infer<typeof zLLaVAModelFieldInputInstance>;
|
||||
export type LLaVAModelFieldInputTemplate = z.infer<typeof zLLaVAModelFieldInputTemplate>;
|
||||
export const isLLaVAModelFieldInputInstance = buildInstanceTypeGuard(zLLaVAModelFieldInputInstance);
|
||||
export const isLLaVAModelFieldInputTemplate = buildTemplateTypeGuard<LLaVAModelFieldInputTemplate>('LLaVAModelField');
|
||||
// #endregion
|
||||
|
||||
// #region ControlNetModelField
|
||||
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
|
||||
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
@@ -1739,6 +1765,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zLLaVAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
@@ -1785,6 +1812,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
zLoRAModelFieldInputInstance,
|
||||
zLLaVAModelFieldInputInstance,
|
||||
zControlNetModelFieldInputInstance,
|
||||
zIPAdapterModelFieldInputInstance,
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
@@ -1825,6 +1853,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zSDXLRefinerModelFieldInputTemplate,
|
||||
zVAEModelFieldInputTemplate,
|
||||
zLoRAModelFieldInputTemplate,
|
||||
zLLaVAModelFieldInputTemplate,
|
||||
zControlNetModelFieldInputTemplate,
|
||||
zIPAdapterModelFieldInputTemplate,
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
@@ -1871,6 +1900,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zSDXLRefinerModelFieldOutputTemplate,
|
||||
zVAEModelFieldOutputTemplate,
|
||||
zLoRAModelFieldOutputTemplate,
|
||||
zLLaVAModelFieldOutputTemplate,
|
||||
zControlNetModelFieldOutputTemplate,
|
||||
zIPAdapterModelFieldOutputTemplate,
|
||||
zT2IAdapterModelFieldOutputTemplate,
|
||||
|
||||
@@ -91,14 +91,37 @@ const zNodeFieldIntegerSettings = z.object({
|
||||
export type NodeFieldIntegerSettings = z.infer<typeof zNodeFieldIntegerSettings>;
|
||||
export const getIntegerFieldSettingsDefaults = (): NodeFieldIntegerSettings => zNodeFieldIntegerSettings.parse({});
|
||||
|
||||
export const zStringComponent = z.enum(['input', 'textarea']);
|
||||
const zStringOption = z
|
||||
.object({
|
||||
label: z.string(),
|
||||
value: z.string(),
|
||||
})
|
||||
.default({ label: '', value: '' });
|
||||
type StringOption = z.infer<typeof zStringOption>;
|
||||
export const getDefaultStringOption = (): StringOption => ({ label: '', value: '' });
|
||||
export const zStringComponent = z.enum(['input', 'textarea', 'dropdown']);
|
||||
const STRING_FIELD_CONFIG_TYPE = 'string-field-config';
|
||||
const zNodeFieldStringSettings = z.object({
|
||||
const zNodeFieldStringInputSettings = z.object({
|
||||
type: z.literal(STRING_FIELD_CONFIG_TYPE).default(STRING_FIELD_CONFIG_TYPE),
|
||||
component: zStringComponent.default('input'),
|
||||
component: z.literal('input').default('input'),
|
||||
});
|
||||
const zNodeFieldStringTextareaSettings = z.object({
|
||||
type: z.literal(STRING_FIELD_CONFIG_TYPE).default(STRING_FIELD_CONFIG_TYPE),
|
||||
component: z.literal('textarea').default('textarea'),
|
||||
});
|
||||
const zNodeFieldStringDropdownSettings = z.object({
|
||||
type: z.literal(STRING_FIELD_CONFIG_TYPE).default(STRING_FIELD_CONFIG_TYPE),
|
||||
component: z.literal('dropdown').default('dropdown'),
|
||||
options: z.array(zStringOption),
|
||||
});
|
||||
export type NodeFieldStringDropdownSettings = z.infer<typeof zNodeFieldStringDropdownSettings>;
|
||||
const zNodeFieldStringSettings = z.union([
|
||||
zNodeFieldStringInputSettings,
|
||||
zNodeFieldStringTextareaSettings,
|
||||
zNodeFieldStringDropdownSettings,
|
||||
]);
|
||||
export type NodeFieldStringSettings = z.infer<typeof zNodeFieldStringSettings>;
|
||||
export const getStringFieldSettingsDefaults = (): NodeFieldStringSettings => zNodeFieldStringSettings.parse({});
|
||||
export const getStringFieldSettingsDefaults = (): NodeFieldStringSettings => zNodeFieldStringInputSettings.parse({});
|
||||
|
||||
const zNodeFieldData = z.object({
|
||||
fieldIdentifier: zFieldIdentifier,
|
||||
|
||||
@@ -4,9 +4,8 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types';
|
||||
import { getControlLayerWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { ImageDTO, Invocation } from 'services/api/types';
|
||||
import type { ImageDTO, Invocation, MainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
@@ -17,7 +16,7 @@ type AddControlNetsArg = {
|
||||
g: Graph;
|
||||
rect: Rect;
|
||||
collector: Invocation<'collect'>;
|
||||
model: ParameterModel;
|
||||
model: MainModelConfig;
|
||||
};
|
||||
|
||||
type AddControlNetsResult = {
|
||||
@@ -66,7 +65,7 @@ type AddT2IAdaptersArg = {
|
||||
g: Graph;
|
||||
rect: Rect;
|
||||
collector: Invocation<'collect'>;
|
||||
model: ParameterModel;
|
||||
model: MainModelConfig;
|
||||
};
|
||||
|
||||
type AddT2IAdaptersResult = {
|
||||
@@ -114,7 +113,7 @@ type AddControlLoRAArg = {
|
||||
entities: CanvasControlLayerState[];
|
||||
g: Graph;
|
||||
rect: Rect;
|
||||
model: ParameterModel;
|
||||
model: MainModelConfig;
|
||||
denoise: Invocation<'flux_denoise'>;
|
||||
};
|
||||
|
||||
@@ -129,9 +128,9 @@ export const addControlLoRA = async ({ manager, entities, g, rect, model, denois
|
||||
// No valid control LoRA found
|
||||
return;
|
||||
}
|
||||
if (validControlLayers.length > 1) {
|
||||
throw new Error('Cannot add more than one FLUX control LoRA.');
|
||||
}
|
||||
|
||||
assert(model.variant !== 'inpaint', 'FLUX Control LoRA is not compatible with FLUX Fill.');
|
||||
assert(validControlLayers.length <= 1, 'Cannot add more than one FLUX control LoRA.');
|
||||
|
||||
const getImageDTOResult = await withResultAsync(() => {
|
||||
const adapter = manager.adapters.controlLayers.get(validControlLayer.id);
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type { Dimensions } from 'features/controlLayers/store/types';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
|
||||
type AddFLUXFillArg = {
|
||||
state: RootState;
|
||||
g: Graph;
|
||||
manager: CanvasManager;
|
||||
l2i: Invocation<'flux_vae_decode'>;
|
||||
denoise: Invocation<'flux_denoise'>;
|
||||
originalSize: Dimensions;
|
||||
scaledSize: Dimensions;
|
||||
};
|
||||
|
||||
export const addFLUXFill = async ({
|
||||
state,
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
denoise,
|
||||
originalSize,
|
||||
scaledSize,
|
||||
}: AddFLUXFillArg): Promise<Invocation<'invokeai_img_blend' | 'apply_mask_to_image'>> => {
|
||||
// FLUX Fill always fully denoises
|
||||
denoise.denoising_start = 0;
|
||||
denoise.denoising_end = 1;
|
||||
|
||||
const params = selectParamsSlice(state);
|
||||
const canvasSettings = selectCanvasSettingsSlice(state);
|
||||
const canvas = selectCanvasSlice(state);
|
||||
|
||||
const { bbox } = canvas;
|
||||
|
||||
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const initialImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
|
||||
const inpaintMaskAdapters = manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
|
||||
const maskImage = await manager.compositor.getCompositeImageDTO(inpaintMaskAdapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
|
||||
const fluxFill = g.addNode({ type: 'flux_fill', id: getPrefixedId('flux_fill') });
|
||||
|
||||
const needsScaleBeforeProcessing = !isEqual(scaledSize, originalSize);
|
||||
|
||||
if (needsScaleBeforeProcessing) {
|
||||
// Scale before processing requires some resizing
|
||||
|
||||
// Combine the inpaint mask and the initial image's alpha channel into a single mask
|
||||
const maskAlphaToMask = g.addNode({
|
||||
id: getPrefixedId('alpha_to_mask'),
|
||||
type: 'tomask',
|
||||
image: { image_name: maskImage.image_name },
|
||||
invert: !canvasSettings.preserveMask,
|
||||
});
|
||||
const initialImageAlphaToMask = g.addNode({
|
||||
id: getPrefixedId('image_alpha_to_mask'),
|
||||
type: 'tomask',
|
||||
image: { image_name: initialImage.image_name },
|
||||
});
|
||||
const maskCombine = g.addNode({
|
||||
id: getPrefixedId('mask_combine'),
|
||||
type: 'mask_combine',
|
||||
});
|
||||
g.addEdge(maskAlphaToMask, 'image', maskCombine, 'mask1');
|
||||
g.addEdge(initialImageAlphaToMask, 'image', maskCombine, 'mask2');
|
||||
|
||||
// Resize the combined and initial image to the scaled size
|
||||
const resizeInputMaskToScaledSize = g.addNode({
|
||||
id: getPrefixedId('resize_mask_to_scaled_size'),
|
||||
type: 'img_resize',
|
||||
...scaledSize,
|
||||
});
|
||||
g.addEdge(maskCombine, 'image', resizeInputMaskToScaledSize, 'image');
|
||||
|
||||
const alphaMaskToTensorMask = g.addNode({
|
||||
type: 'image_mask_to_tensor',
|
||||
id: getPrefixedId('image_mask_to_tensor'),
|
||||
});
|
||||
g.addEdge(resizeInputMaskToScaledSize, 'image', alphaMaskToTensorMask, 'image');
|
||||
g.addEdge(alphaMaskToTensorMask, 'mask', fluxFill, 'mask');
|
||||
|
||||
// Resize the initial image to the scaled size and add to the FLUX Fill node
|
||||
const resizeInputImageToScaledSize = g.addNode({
|
||||
id: getPrefixedId('resize_image_to_scaled_size'),
|
||||
type: 'img_resize',
|
||||
image: { image_name: initialImage.image_name },
|
||||
...scaledSize,
|
||||
});
|
||||
g.addEdge(resizeInputImageToScaledSize, 'image', fluxFill, 'image');
|
||||
|
||||
// Provide the FLUX Fill conditioning w/ image and mask to the denoise node
|
||||
g.addEdge(fluxFill, 'fill_cond', denoise, 'fill_conditioning');
|
||||
|
||||
// Resize the output image back to the original size
|
||||
const resizeOutputImageToOriginalSize = g.addNode({
|
||||
id: getPrefixedId('resize_image_to_original_size'),
|
||||
type: 'img_resize',
|
||||
...originalSize,
|
||||
});
|
||||
// After denoising, resize the image and mask back to original size
|
||||
g.addEdge(l2i, 'image', resizeOutputImageToOriginalSize, 'image');
|
||||
|
||||
const expandMask = g.addNode({
|
||||
type: 'expand_mask_with_fade',
|
||||
id: getPrefixedId('expand_mask_with_fade'),
|
||||
fade_size_px: params.maskBlur,
|
||||
});
|
||||
g.addEdge(maskCombine, 'image', expandMask, 'mask');
|
||||
|
||||
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
|
||||
// to canvas but not outputting only masked regions
|
||||
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
|
||||
const imageLayerBlend = g.addNode({
|
||||
type: 'invokeai_img_blend',
|
||||
id: getPrefixedId('image_layer_blend'),
|
||||
layer_base: { image_name: initialImage.image_name },
|
||||
});
|
||||
g.addEdge(resizeOutputImageToOriginalSize, 'image', imageLayerBlend, 'layer_upper');
|
||||
g.addEdge(expandMask, 'image', imageLayerBlend, 'mask');
|
||||
return imageLayerBlend;
|
||||
} else {
|
||||
// Otherwise, just apply the mask
|
||||
const applyMaskToImage = g.addNode({
|
||||
type: 'apply_mask_to_image',
|
||||
id: getPrefixedId('apply_mask_to_image'),
|
||||
invert_mask: true,
|
||||
});
|
||||
g.addEdge(expandMask, 'image', applyMaskToImage, 'mask');
|
||||
g.addEdge(resizeOutputImageToOriginalSize, 'image', applyMaskToImage, 'image');
|
||||
return applyMaskToImage;
|
||||
}
|
||||
} else {
|
||||
// Combine the inpaint mask and the initial image's alpha channel into a single mask
|
||||
const maskAlphaToMask = g.addNode({
|
||||
id: getPrefixedId('alpha_to_mask'),
|
||||
type: 'tomask',
|
||||
image: { image_name: maskImage.image_name },
|
||||
invert: !canvasSettings.preserveMask,
|
||||
});
|
||||
const initialImageAlphaToMask = g.addNode({
|
||||
id: getPrefixedId('image_alpha_to_mask'),
|
||||
type: 'tomask',
|
||||
image: { image_name: initialImage.image_name },
|
||||
});
|
||||
const maskCombine = g.addNode({
|
||||
id: getPrefixedId('mask_combine'),
|
||||
type: 'mask_combine',
|
||||
});
|
||||
g.addEdge(maskAlphaToMask, 'image', maskCombine, 'mask1');
|
||||
g.addEdge(initialImageAlphaToMask, 'image', maskCombine, 'mask2');
|
||||
|
||||
const alphaMaskToTensorMask = g.addNode({
|
||||
type: 'image_mask_to_tensor',
|
||||
id: getPrefixedId('image_mask_to_tensor'),
|
||||
});
|
||||
g.addEdge(maskCombine, 'image', alphaMaskToTensorMask, 'image');
|
||||
g.addEdge(alphaMaskToTensorMask, 'mask', fluxFill, 'mask');
|
||||
|
||||
fluxFill.image = { image_name: initialImage.image_name };
|
||||
g.addEdge(fluxFill, 'fill_cond', denoise, 'fill_conditioning');
|
||||
|
||||
const expandMask = g.addNode({
|
||||
type: 'expand_mask_with_fade',
|
||||
id: getPrefixedId('expand_mask_with_fade'),
|
||||
fade_size_px: params.maskBlur,
|
||||
});
|
||||
g.addEdge(maskCombine, 'image', expandMask, 'mask');
|
||||
|
||||
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
|
||||
// to canvas but not outputting only masked regions
|
||||
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
|
||||
const imageLayerBlend = g.addNode({
|
||||
type: 'invokeai_img_blend',
|
||||
id: getPrefixedId('image_layer_blend'),
|
||||
layer_base: { image_name: initialImage.image_name },
|
||||
});
|
||||
g.addEdge(l2i, 'image', imageLayerBlend, 'layer_upper');
|
||||
g.addEdge(expandMask, 'image', imageLayerBlend, 'mask');
|
||||
return imageLayerBlend;
|
||||
} else {
|
||||
// Otherwise, just apply the mask
|
||||
const applyMaskToImage = g.addNode({
|
||||
type: 'apply_mask_to_image',
|
||||
id: getPrefixedId('apply_mask_to_image'),
|
||||
invert_mask: true,
|
||||
});
|
||||
g.addEdge(expandMask, 'image', applyMaskToImage, 'mask');
|
||||
g.addEdge(l2i, 'image', applyMaskToImage, 'image');
|
||||
return applyMaskToImage;
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -2,8 +2,7 @@ import type { CanvasReferenceImageState, FLUXReduxConfig } from 'features/contro
|
||||
import { isFLUXReduxConfig } from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { Invocation, MainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type AddFLUXReduxResult = {
|
||||
@@ -14,7 +13,7 @@ type AddFLUXReduxArg = {
|
||||
entities: CanvasReferenceImageState[];
|
||||
g: Graph;
|
||||
collector: Invocation<'collect'>;
|
||||
model: ParameterModel;
|
||||
model: MainModelConfig;
|
||||
};
|
||||
|
||||
export const addFLUXReduxes = ({ entities, g, collector, model }: AddFLUXReduxArg): AddFLUXReduxResult => {
|
||||
|
||||
@@ -5,8 +5,7 @@ import {
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { Invocation, MainModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type AddIPAdaptersResult = {
|
||||
@@ -17,7 +16,7 @@ type AddIPAdaptersArg = {
|
||||
entities: CanvasReferenceImageState[];
|
||||
g: Graph;
|
||||
collector: Invocation<'collect'>;
|
||||
model: ParameterModel;
|
||||
model: MainModelConfig;
|
||||
};
|
||||
|
||||
export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {
|
||||
|
||||
@@ -39,7 +39,7 @@ export const addInpaint = async ({
|
||||
scaledSize,
|
||||
denoising_start,
|
||||
fp32,
|
||||
}: AddInpaintArg): Promise<Invocation<'canvas_v2_mask_and_crop' | 'img_resize'>> => {
|
||||
}: AddInpaintArg): Promise<Invocation<'invokeai_img_blend' | 'apply_mask_to_image'>> => {
|
||||
denoise.denoising_start = denoising_start;
|
||||
|
||||
const params = selectParamsSlice(state);
|
||||
@@ -60,7 +60,9 @@ export const addInpaint = async ({
|
||||
silent: true,
|
||||
});
|
||||
|
||||
if (!isEqual(scaledSize, originalSize)) {
|
||||
const needsScaleBeforeProcessing = !isEqual(scaledSize, originalSize);
|
||||
|
||||
if (needsScaleBeforeProcessing) {
|
||||
// Scale before processing requires some resizing
|
||||
const i2l = g.addNode({
|
||||
id: i2lNodeType,
|
||||
@@ -104,10 +106,10 @@ export const addInpaint = async ({
|
||||
edge_radius: params.canvasCoherenceEdgeSize,
|
||||
fp32,
|
||||
});
|
||||
const canvasPasteBack = g.addNode({
|
||||
id: getPrefixedId('canvas_v2_mask_and_crop'),
|
||||
type: 'canvas_v2_mask_and_crop',
|
||||
mask_blur: params.maskBlur,
|
||||
const expandMask = g.addNode({
|
||||
type: 'expand_mask_with_fade',
|
||||
id: getPrefixedId('expand_mask_with_fade'),
|
||||
fade_size_px: params.maskBlur,
|
||||
});
|
||||
|
||||
// Resize initial image and mask to scaled size, feed into to gradient mask
|
||||
@@ -127,19 +129,32 @@ export const addInpaint = async ({
|
||||
|
||||
// After denoising, resize the image and mask back to original size
|
||||
g.addEdge(l2i, 'image', resizeImageToOriginalSize, 'image');
|
||||
g.addEdge(createGradientMask, 'expanded_mask_area', resizeMaskToOriginalSize, 'image');
|
||||
|
||||
// Finally, paste the generated masked image back onto the original image
|
||||
g.addEdge(resizeImageToOriginalSize, 'image', canvasPasteBack, 'generated_image');
|
||||
g.addEdge(resizeMaskToOriginalSize, 'image', canvasPasteBack, 'mask');
|
||||
g.addEdge(createGradientMask, 'expanded_mask_area', expandMask, 'mask');
|
||||
g.addEdge(expandMask, 'image', resizeMaskToOriginalSize, 'image');
|
||||
|
||||
// After denoising, resize the image and mask back to original size
|
||||
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
|
||||
// to canvas but not outputting only masked regions
|
||||
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
|
||||
canvasPasteBack.source_image = { image_name: initialImage.image_name };
|
||||
const imageLayerBlend = g.addNode({
|
||||
type: 'invokeai_img_blend',
|
||||
id: getPrefixedId('image_layer_blend'),
|
||||
layer_base: { image_name: initialImage.image_name },
|
||||
});
|
||||
g.addEdge(resizeImageToOriginalSize, 'image', imageLayerBlend, 'layer_upper');
|
||||
g.addEdge(resizeMaskToOriginalSize, 'image', imageLayerBlend, 'mask');
|
||||
return imageLayerBlend;
|
||||
} else {
|
||||
// Otherwise, just apply the mask
|
||||
const applyMaskToImage = g.addNode({
|
||||
type: 'apply_mask_to_image',
|
||||
id: getPrefixedId('apply_mask_to_image'),
|
||||
invert_mask: true,
|
||||
});
|
||||
g.addEdge(resizeMaskToOriginalSize, 'image', applyMaskToImage, 'mask');
|
||||
g.addEdge(resizeImageToOriginalSize, 'image', applyMaskToImage, 'image');
|
||||
return applyMaskToImage;
|
||||
}
|
||||
|
||||
return canvasPasteBack;
|
||||
} else {
|
||||
// No scale before processing, much simpler
|
||||
const i2l = g.addNode({
|
||||
@@ -164,11 +179,6 @@ export const addInpaint = async ({
|
||||
fp32,
|
||||
image: { image_name: initialImage.image_name },
|
||||
});
|
||||
const canvasPasteBack = g.addNode({
|
||||
id: getPrefixedId('canvas_v2_mask_and_crop'),
|
||||
type: 'canvas_v2_mask_and_crop',
|
||||
mask_blur: params.maskBlur,
|
||||
});
|
||||
|
||||
g.addEdge(alphaToMask, 'image', createGradientMask, 'mask');
|
||||
g.addEdge(i2l, 'latents', denoise, 'latents');
|
||||
@@ -178,16 +188,35 @@ export const addInpaint = async ({
|
||||
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
|
||||
}
|
||||
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
|
||||
g.addEdge(createGradientMask, 'expanded_mask_area', canvasPasteBack, 'mask');
|
||||
|
||||
g.addEdge(l2i, 'image', canvasPasteBack, 'generated_image');
|
||||
const expandMask = g.addNode({
|
||||
type: 'expand_mask_with_fade',
|
||||
id: getPrefixedId('expand_mask_with_fade'),
|
||||
fade_size_px: params.maskBlur,
|
||||
});
|
||||
g.addEdge(createGradientMask, 'expanded_mask_area', expandMask, 'mask');
|
||||
|
||||
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
|
||||
// to canvas but not outputting only masked regions
|
||||
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
|
||||
canvasPasteBack.source_image = { image_name: initialImage.image_name };
|
||||
const imageLayerBlend = g.addNode({
|
||||
type: 'invokeai_img_blend',
|
||||
id: getPrefixedId('image_layer_blend'),
|
||||
layer_base: { image_name: initialImage.image_name },
|
||||
});
|
||||
g.addEdge(l2i, 'image', imageLayerBlend, 'layer_upper');
|
||||
g.addEdge(expandMask, 'image', imageLayerBlend, 'mask');
|
||||
return imageLayerBlend;
|
||||
} else {
|
||||
// Otherwise, just apply the mask
|
||||
const applyMaskToImage = g.addNode({
|
||||
type: 'apply_mask_to_image',
|
||||
id: getPrefixedId('apply_mask_to_image'),
|
||||
invert_mask: true,
|
||||
});
|
||||
g.addEdge(expandMask, 'image', applyMaskToImage, 'mask');
|
||||
g.addEdge(l2i, 'image', applyMaskToImage, 'image');
|
||||
return applyMaskToImage;
|
||||
}
|
||||
|
||||
return canvasPasteBack;
|
||||
}
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user