Compare commits
131 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac1071a5e5 | ||
|
|
5295a398f3 | ||
|
|
0c7283c82d | ||
|
|
73ad173c74 | ||
|
|
c828a4e59f | ||
|
|
6bab040d24 | ||
|
|
f46bbaf8c4 | ||
|
|
fce6b3e44c | ||
|
|
d27907cc6d | ||
|
|
7ee3fef2db | ||
|
|
b39ce642b6 | ||
|
|
a148c4322c | ||
|
|
f6b7bc5d98 | ||
|
|
5f6c6abf9c | ||
|
|
cd76a31a8f | ||
|
|
e93f4d632d | ||
|
|
5a8489bbfc | ||
|
|
a24c9d0f7a | ||
|
|
7a92afc117 | ||
|
|
b508945b11 | ||
|
|
7cf788e658 | ||
|
|
06bc38d3f4 | ||
|
|
d3b0212da5 | ||
|
|
c2b79ce14c | ||
|
|
70185b0173 | ||
|
|
a83a0c6146 | ||
|
|
12f41039cc | ||
|
|
b3b5b7e261 | ||
|
|
f706a13230 | ||
|
|
22c6400bb8 | ||
|
|
1ca152f6c8 | ||
|
|
982e255878 | ||
|
|
7899149144 | ||
|
|
bef97b46bf | ||
|
|
cc256fee0e | ||
|
|
ec69a58c8d | ||
|
|
ec67ba61db | ||
|
|
66126996e7 | ||
|
|
4eb66a9198 | ||
|
|
14e41a1fd9 | ||
|
|
fc55522003 | ||
|
|
cd6d8ae9cc | ||
|
|
2933eb594d | ||
|
|
4e08fab3f5 | ||
|
|
8bca7e2aa2 | ||
|
|
3706cf0ad4 | ||
|
|
a459361376 | ||
|
|
bb330d50a6 | ||
|
|
102cb62960 | ||
|
|
8eeab22ecd | ||
|
|
4343852b83 | ||
|
|
0a9bf25bff | ||
|
|
4cd09850b8 | ||
|
|
dbc586e0b2 | ||
|
|
8426f1e7b2 | ||
|
|
c2e3c61f28 | ||
|
|
fbfa29c2ef | ||
|
|
9ee7b951eb | ||
|
|
29dd1bb35b | ||
|
|
68d8a2497e | ||
|
|
4b171fa696 | ||
|
|
d0beb45431 | ||
|
|
e724781a80 | ||
|
|
636ece323f | ||
|
|
77b3281f08 | ||
|
|
bd7c8cd517 | ||
|
|
489d485907 | ||
|
|
6eed5ad531 | ||
|
|
9cb0f63c44 | ||
|
|
2d5786d3bb | ||
|
|
27466ffa1a | ||
|
|
f50b156511 | ||
|
|
9fc73743b2 | ||
|
|
d4393e4170 | ||
|
|
145a0b029e | ||
|
|
f2506cc769 | ||
|
|
7a67fd6a06 | ||
|
|
af36fe8c1e | ||
|
|
e9f16ac8c7 | ||
|
|
6ea183f0d4 | ||
|
|
24f2cde862 | ||
|
|
b18442ded4 | ||
|
|
651c0b39b1 | ||
|
|
46d23cd868 | ||
|
|
dedf0c6ffa | ||
|
|
579082ac10 | ||
|
|
7bc77ddb40 | ||
|
|
026d095afe | ||
|
|
7e2ade50e1 | ||
|
|
c0d54d5414 | ||
|
|
98bfbb73ac | ||
|
|
f9af32a6d1 | ||
|
|
fba40eb1bd | ||
|
|
69f6c24f52 | ||
|
|
80d631118d | ||
|
|
0c6dd32ece | ||
|
|
0bdbfd4d1d | ||
|
|
2e27ed5f3d | ||
|
|
babdc64b17 | ||
|
|
54327ec4a7 | ||
|
|
4a828818da | ||
|
|
fe386252f3 | ||
|
|
182810337c | ||
|
|
338bf808d6 | ||
|
|
5b5a4204a1 | ||
|
|
75ef473748 | ||
|
|
926b8d0efe | ||
|
|
9d9d1761f3 | ||
|
|
a78df8123f | ||
|
|
7ca677578e | ||
|
|
31c456c1e6 | ||
|
|
2ce79b61f5 | ||
|
|
109e3f0e7f | ||
|
|
dc64fec771 | ||
|
|
d1e45585d0 | ||
|
|
aba023e0c5 | ||
|
|
e354c29b52 | ||
|
|
a7f363e654 | ||
|
|
9b2162e564 | ||
|
|
4e64b26702 | ||
|
|
c22d772062 | ||
|
|
d6be7662c9 | ||
|
|
95050088d1 | ||
|
|
94b5084cd5 | ||
|
|
ca0d60bee6 | ||
|
|
fd1f240853 | ||
|
|
381b41a56e | ||
|
|
b58494c420 | ||
|
|
dca30d5462 | ||
|
|
9ab6655491 | ||
|
|
29cfe5a274 |
BIN
docs/assets/gallery/board_settings.png
Normal file
|
After Width: | Height: | Size: 23 KiB |
BIN
docs/assets/gallery/board_tabs.png
Normal file
|
After Width: | Height: | Size: 2.7 KiB |
BIN
docs/assets/gallery/board_thumbnails.png
Normal file
|
After Width: | Height: | Size: 30 KiB |
BIN
docs/assets/gallery/gallery.png
Normal file
|
After Width: | Height: | Size: 221 KiB |
BIN
docs/assets/gallery/image_menu.png
Normal file
|
After Width: | Height: | Size: 53 KiB |
BIN
docs/assets/gallery/info_button.png
Normal file
|
After Width: | Height: | Size: 786 B |
BIN
docs/assets/gallery/thumbnail_menu.png
Normal file
|
After Width: | Height: | Size: 27 KiB |
BIN
docs/assets/gallery/top_controls.png
Normal file
|
After Width: | Height: | Size: 3.3 KiB |
92
docs/features/GALLERY.md
Normal file
@@ -0,0 +1,92 @@
|
||||
---
|
||||
title: InvokeAI Gallery Panel
|
||||
---
|
||||
|
||||
# :material-web: InvokeAI Gallery Panel
|
||||
|
||||
## Quick guided walkthrough of the Gallery Panel's features
|
||||
|
||||
The Gallery Panel is a fast way to review, find, and make use of images you've
|
||||
generated and loaded. The Gallery is divided into Boards. The Uncategorized board is always
|
||||
present but you can create your own for better organization.
|
||||
|
||||

|
||||
|
||||
### Board Display and Settings
|
||||
|
||||
At the very top of the Gallery Panel are the boards disclosure and settings buttons.
|
||||
|
||||

|
||||
|
||||
The disclosure button shows the name of the currently selected board and allows you to show and hide the board thumbnails (shown in the image below).
|
||||
|
||||

|
||||
|
||||
The settings button opens a list of options.
|
||||
|
||||

|
||||
|
||||
- ***Image Size*** this slider lets you control the size of the image previews (images of three different sizes).
|
||||
- ***Auto-Switch to New Images*** if you turn this on, whenever a new image is generated, it will automatically be loaded into the current image panel on the Text to Image tab and into the result panel on the [Image to Image](IMG2IMG.md) tab. This will happen invisibly if you are on any other tab when the image is generated.
|
||||
- ***Auto-Assign Board on Click*** whenever an image is generated or saved, it always gets put in a board. The board it gets put into is marked with AUTO (image of board marked). Turning on Auto-Assign Board on Click will make whichever board you last selected be the destination when you click Invoke. That means you can click Invoke, select a different board, and then click Invoke again and the two images will be put in two different boards. (bold)It's the board selected when Invoke is clicked that's used, not the board that's selected when the image is finished generating.(bold) Turning this off, enables the Auto-Add Board drop down which lets you set one specific board to always put generated images into. This also enables and disables the Auto-add to this Board menu item described below.
|
||||
- ***Always Show Image Size Badge*** this toggles whether to show image sizes for each image preview (show two images, one with sizes shown, one without)
|
||||
|
||||
Below these two buttons, you'll see the Search Boards text entry area. You use this to search for specific boards by the name of the board.
|
||||
Next to it is the Add Board (+) button which lets you add new boards. Boards can be renamed by clicking on the name of the board under its thumbnail and typing in the new name.
|
||||
|
||||
### Board Thumbnail Menu
|
||||
|
||||
Each board has a context menu (ctrl+click / right-click).
|
||||
|
||||

|
||||
|
||||
- ***Auto-add to this Board*** if you've disabled Auto-Assign Board on Click in the board settings, you can use this option to set this board to be where new images are put.
|
||||
- ***Download Board*** this will add all the images in the board into a zip file and provide a link to it in a notification (image of notification)
|
||||
- ***Delete Board*** this will delete the board
|
||||
> [!CAUTION]
|
||||
> This will delete all the images in the board and the board itself.
|
||||
|
||||
### Board Contents
|
||||
|
||||
Every board is organized by two tabs, Images and Assets.
|
||||
|
||||

|
||||
|
||||
Images are the Invoke-generated images that are placed into the board. Assets are images that you upload into Invoke to be used as an [Image Prompt](https://support.invoke.ai/support/solutions/articles/151000159340-using-the-image-prompt-adapter-ip-adapter-) or in the [Image to Image](IMG2IMG.md) tab.
|
||||
|
||||
### Image Thumbnail Menu
|
||||
|
||||
Every image generated by Invoke has its generation information stored as text inside the image file itself. This can be read directly by selecting the image and clicking on the Info button  in any of the image result panels.
|
||||
|
||||
Each image also has a context menu (ctrl+click / right-click).
|
||||
|
||||

|
||||
|
||||
The options are (items marked with an * will not work with images that lack generation information):
|
||||
- ***Open in New Tab*** this will open the image alone in a new browser tab, separate from the Invoke interface.
|
||||
- ***Download Image*** this will trigger your browser to download the image.
|
||||
- ***Load Workflow **** this will load any workflow settings into the Workflow tab and automatically open it.
|
||||
- ***Remix Image **** this will load all of the image's generation information, (bold)excluding its Seed, into the left hand control panel
|
||||
- ***Use Prompt **** this will load only the image's text prompts into the left-hand control panel
|
||||
- ***Use Seed **** this will load only the image's Seed into the left-hand control panel
|
||||
- ***Use All **** this will load all of the image's generation information into the left-hand control panel
|
||||
- ***Send to Image to Image*** this will put the image into the left-hand panel in the Image to Image tab ana automatically open it
|
||||
- ***Send to Unified Canvas*** This will (bold)replace whatever is already present(bold) in the Unified Canvas tab with the image and automatically open the tab
|
||||
- ***Change Board*** this will oipen a small window that will let you move the image to a different board. This is the same as dragging the image to that board's thumbnail.
|
||||
- ***Star Image*** this will add the image to the board's list of starred images that are always kept at the top of the gallery. This is the same as clicking on the star on the top right-hand side of the image that appears when you hover over the image with the mouse
|
||||
- ***Delete Image*** this will delete the image from the board
|
||||
> [!CAUTION]
|
||||
> This will delete the image entirely from Invoke.
|
||||
|
||||
## Summary
|
||||
|
||||
This walkthrough only covers the Gallery interface and Boards. Actually generating images is handled by [Prompts](PROMPTS.md), the [Image to Image](IMG2IMG.md) tab, and the [Unified Canvas](UNIFIED_CANVAS.md).
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
A huge shout-out to the core team working to make the Web GUI a reality,
|
||||
including [psychedelicious](https://github.com/psychedelicious),
|
||||
[Kyle0654](https://github.com/Kyle0654) and
|
||||
[blessedcoolant](https://github.com/blessedcoolant).
|
||||
[hipsterusername](https://github.com/hipsterusername) was the team's unofficial
|
||||
cheerleader and added tooltips/docs.
|
||||
@@ -108,40 +108,6 @@ Can be used with .and():
|
||||
Each will give you different results - try them out and see what you prefer!
|
||||
|
||||
|
||||
|
||||
### Cross-Attention Control ('prompt2prompt')
|
||||
|
||||
Sometimes an image you generate is almost right, and you just want to change one
|
||||
detail without affecting the rest. You could use a photo editor and inpainting
|
||||
to overpaint the area, but that's a pain. Here's where `prompt2prompt` comes in
|
||||
handy.
|
||||
|
||||
Generate an image with a given prompt, record the seed of the image, and then
|
||||
use the `prompt2prompt` syntax to substitute words in the original prompt for
|
||||
words in a new prompt. This works for `img2img` as well.
|
||||
|
||||
For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because the words interact with each other when doing a stable diffusion image generation, these two prompts would generate different compositions:
|
||||
- `a cat playing with a ball in the forest`
|
||||
- `a dog playing with a ball in the forest`
|
||||
|
||||
| `a cat playing with a ball in the forest` | `a dog playing with a ball in the forest` |
|
||||
| --- | --- |
|
||||
| img | img |
|
||||
|
||||
|
||||
- For multiple word swaps, use parentheses: `a (fluffy cat).swap(barking dog) playing with a ball in the forest`.
|
||||
- To swap a comma, use quotes: `a ("fluffy, grey cat").swap("big, barking dog") playing with a ball in the forest`.
|
||||
- Supports options `t_start` and `t_end` (each 0-1) loosely corresponding to (bloc97's)[(https://github.com/bloc97/CrossAttentionControl)] `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to
|
||||
intuitively understand. `t_start` and `t_end` are used to control on which steps cross-attention control should run. With the default values `t_start=0` and `t_end=1`, cross-attention control is active on every step of image generation. Other values can be used to turn cross-attention control off for part of the image generation process.
|
||||
- For example, if doing a diffusion with 10 steps for the prompt is `a cat.swap(dog, t_start=0.3, t_end=1.0) playing with a ball in the forest`, the first 3 steps will be run as `a cat playing with a ball in the forest`, while the last 7 steps will run as `a dog playing with a ball in the forest`, but the pixels that represent `dog` will be locked to the pixels that would have represented `cat` if the `cat` prompt had been used instead.
|
||||
- Conversely, for `a cat.swap(dog, t_start=0, t_end=0.7) playing with a ball in the forest`, the first 7 steps will run as `a dog playing with a ball in the forest` with the pixels that represent `dog` locked to the same pixels that would have represented `cat` if the `cat` prompt was being used instead. The final 3 steps will just run `a cat playing with a ball in the forest`.
|
||||
> For img2img, the step sequence does not start at 0 but instead at `(1.0-strength)` - so if the img2img `strength` is `0.7`, `t_start` and `t_end` must both be greater than `0.3` (`1.0-0.7`) to have any effect.
|
||||
|
||||
Prompt2prompt `.swap()` is not compatible with xformers, which will be temporarily disabled when doing a `.swap()` - so you should expect to use more VRAM and run slower that with xformers enabled.
|
||||
|
||||
The `prompt2prompt` code is based off
|
||||
[bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
|
||||
|
||||
### Escaping parentheses and speech marks
|
||||
|
||||
If the model you are using has parentheses () or speech marks "" as part of its
|
||||
|
||||
@@ -54,7 +54,7 @@ main sections:
|
||||
of buttons at the top lets you modify and manipulate the image in
|
||||
various ways.
|
||||
|
||||
3. A **gallery** section on the left that contains a history of the images you
|
||||
3. A **gallery** section on the right that contains a history of the images you
|
||||
have generated. These images are read and written to the directory specified
|
||||
in the `INVOKEAIROOT/invokeai.yaml` initialization file, usually a directory
|
||||
named `outputs` in `INVOKEAIROOT`.
|
||||
|
||||
@@ -23,6 +23,7 @@ If you have an interest in how InvokeAI works, or you would like to add features
|
||||
|
||||
1. [Fork and clone] the [InvokeAI repo].
|
||||
1. Follow the [manual installation] docs to create a new virtual environment for the development install.
|
||||
- Create a new folder outside the repo root for the installation and create the venv inside that folder.
|
||||
- When installing the InvokeAI package, add `-e` to the command so you get an [editable install].
|
||||
1. Install the [frontend dev toolchain] and do a production build of the UI as described.
|
||||
1. You can now run the app as described in the [manual installation] docs.
|
||||
|
||||
@@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.backend.util.devices import get_torch_device_name
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
@@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
torch_device_name = get_torch_device_name()
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
@@ -14,10 +22,9 @@ from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from .model import CLIPField
|
||||
@@ -36,7 +43,7 @@ from .model import CLIPField
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.1.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -51,6 +58,9 @@ class CompelInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@@ -89,7 +99,7 @@ class CompelInvocation(BaseInvocation):
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
@@ -98,27 +108,19 @@ class CompelInvocation(BaseInvocation):
|
||||
if context.config.get().log_tokenization:
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
ec = ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
c = c.detach().to("cpu")
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
BasicConditioningInfo(
|
||||
embeds=c,
|
||||
extra_conditioning=ec,
|
||||
)
|
||||
]
|
||||
)
|
||||
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
mask=self.mask,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
@@ -132,7 +134,7 @@ class SDXLPromptInvocationBase:
|
||||
get_pooled: bool,
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
@@ -159,7 +161,7 @@ class SDXLPromptInvocationBase:
|
||||
)
|
||||
else:
|
||||
c_pooled = None
|
||||
return c, c_pooled, None
|
||||
return c, c_pooled
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
@@ -191,7 +193,7 @@ class SDXLPromptInvocationBase:
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
truncate_long_prompts=False, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=get_pooled,
|
||||
@@ -204,17 +206,12 @@ class SDXLPromptInvocationBase:
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
# TODO: ask for optimizations? to not run text_encoder twice
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
if get_pooled:
|
||||
c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
|
||||
else:
|
||||
c_pooled = None
|
||||
|
||||
ec = ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
del tokenizer
|
||||
del text_encoder
|
||||
del tokenizer_info
|
||||
@@ -224,7 +221,7 @@ class SDXLPromptInvocationBase:
|
||||
if c_pooled is not None:
|
||||
c_pooled = c_pooled.detach().to("cpu")
|
||||
|
||||
return c, c_pooled, ec
|
||||
return c, c_pooled
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -232,7 +229,7 @@ class SDXLPromptInvocationBase:
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.1.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -255,20 +252,19 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||
)
|
||||
c1, c1_pooled = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True)
|
||||
if self.style.strip() == "":
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
c2, c2_pooled = self.run_clip_compel(
|
||||
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
@@ -307,17 +303,19 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
embeds=torch.cat([c1, c2], dim=-1),
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec1,
|
||||
embeds=torch.cat([c1, c2], dim=-1), pooled_embeds=c2_pooled, add_time_ids=add_time_ids
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
mask=self.mask,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -345,7 +343,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
# TODO: if there will appear lora for refiner - write proper prefix
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||
c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
@@ -354,14 +352,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
assert c2_pooled is not None
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
embeds=c2,
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec2, # or None
|
||||
)
|
||||
]
|
||||
conditionings=[SDXLConditioningInfo(embeds=c2, pooled_embeds=c2_pooled, add_time_ids=add_time_ids)]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
@@ -203,6 +203,12 @@ class DenoiseMaskField(BaseModel):
|
||||
gradient: bool = Field(default=False, description="Used for gradient inpainting")
|
||||
|
||||
|
||||
class TensorField(BaseModel):
|
||||
"""A tensor primitive field."""
|
||||
|
||||
tensor_name: str = Field(description="The name of a tensor.")
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents tensor primitive field"""
|
||||
|
||||
@@ -226,7 +232,11 @@ class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
# endregion
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
||||
class MetadataField(RootModel[dict[str, Any]]):
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from builtins import float
|
||||
from typing import List, Literal, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
@@ -23,13 +23,19 @@ class IPAdapterField(BaseModel):
|
||||
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
|
||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
|
||||
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The bool mask associated with this IP-Adapter. Excluded regions should be set to False, included "
|
||||
"regions should be set to True.",
|
||||
)
|
||||
|
||||
@field_validator("weight")
|
||||
@classmethod
|
||||
@@ -52,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
@@ -73,12 +79,18 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
weight: Union[float, List[float]] = InputField(
|
||||
default=1, description="The weight given to the IP-Adapter", title="Weight"
|
||||
)
|
||||
method: Literal["full", "style", "composition"] = InputField(
|
||||
default="full", description="The method to apply the IP-Adapter"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this IP-Adapter applies to."
|
||||
)
|
||||
|
||||
@field_validator("weight")
|
||||
@classmethod
|
||||
@@ -104,14 +116,35 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
|
||||
|
||||
if self.method == "style":
|
||||
if ip_adapter_info.base == "sd-1":
|
||||
target_blocks = ["up_blocks.1"]
|
||||
elif ip_adapter_info.base == "sdxl":
|
||||
target_blocks = ["up_blocks.0.attentions.1"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
|
||||
elif self.method == "composition":
|
||||
if ip_adapter_info.base == "sd-1":
|
||||
target_blocks = ["down_blocks.2", "mid_block"]
|
||||
elif ip_adapter_info.base == "sdxl":
|
||||
target_blocks = ["down_blocks.2.attentions.1"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
|
||||
elif self.method == "full":
|
||||
target_blocks = ["block"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected IP-Adapter method: '{self.method}'.")
|
||||
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
ip_adapter_model=self.ip_adapter_model,
|
||||
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
|
||||
weight=self.weight,
|
||||
target_blocks=target_blocks,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
mask=self.mask,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
@@ -9,6 +9,7 @@ import einops
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
@@ -52,26 +53,31 @@ from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
IPAdapterConditioningInfo,
|
||||
IPAdapterData,
|
||||
Range,
|
||||
SDXLConditioningInfo,
|
||||
TextConditioningData,
|
||||
TextConditioningRegions,
|
||||
)
|
||||
from invokeai.backend.util.mask import to_standard_float_mask
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
IPAdapterData,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
T2IAdapterData,
|
||||
image_resized_to_grid_as_tensor,
|
||||
)
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from ...backend.util.devices import TorchDevice
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelIdentifierField, UNetField, VAEField
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
|
||||
|
||||
|
||||
@invocation_output("scheduler_output")
|
||||
@@ -275,10 +281,10 @@ def get_scheduler(
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||
)
|
||||
noise: Optional[LatentsField] = InputField(
|
||||
@@ -356,33 +362,168 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
def _get_text_embeddings_and_masks(
|
||||
self,
|
||||
cond_list: list[ConditioningField],
|
||||
context: InvocationContext,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
|
||||
"""Get the text embeddings and masks from the input conditioning fields."""
|
||||
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||
for cond in cond_list:
|
||||
cond_data = context.conditioning.load(cond.conditioning_name)
|
||||
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||
|
||||
mask = cond.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.tensor_name)
|
||||
text_embeddings_masks.append(mask)
|
||||
|
||||
return text_embeddings, text_embeddings_masks
|
||||
|
||||
def _preprocess_regional_prompt_mask(
|
||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
|
||||
"""
|
||||
|
||||
if mask is None:
|
||||
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
|
||||
|
||||
mask = to_standard_float_mask(mask, out_dtype=dtype)
|
||||
|
||||
tf = torchvision.transforms.Resize(
|
||||
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
|
||||
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
resized_mask = tf(mask)
|
||||
return resized_mask
|
||||
|
||||
def _concat_regional_text_embeddings(
|
||||
self,
|
||||
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||
masks: Optional[list[Optional[torch.Tensor]]],
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
|
||||
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
|
||||
if masks is None:
|
||||
masks = [None] * len(text_conditionings)
|
||||
assert len(text_conditionings) == len(masks)
|
||||
|
||||
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
|
||||
|
||||
all_masks_are_none = all(mask is None for mask in masks)
|
||||
|
||||
text_embedding = []
|
||||
pooled_embedding = None
|
||||
add_time_ids = None
|
||||
cur_text_embedding_len = 0
|
||||
processed_masks = []
|
||||
embedding_ranges = []
|
||||
|
||||
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
|
||||
mask = masks[prompt_idx]
|
||||
|
||||
if is_sdxl:
|
||||
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
|
||||
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
|
||||
# global prompt information. In an ideal case, there should be exactly one global prompt without a
|
||||
# mask, but we don't enforce this.
|
||||
|
||||
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
|
||||
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
|
||||
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
|
||||
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
|
||||
# pretty major breaking change to a popular node, so for now we use this hack.
|
||||
if pooled_embedding is None or mask is None:
|
||||
pooled_embedding = text_embedding_info.pooled_embeds
|
||||
if add_time_ids is None or mask is None:
|
||||
add_time_ids = text_embedding_info.add_time_ids
|
||||
|
||||
text_embedding.append(text_embedding_info.embeds)
|
||||
if not all_masks_are_none:
|
||||
embedding_ranges.append(
|
||||
Range(
|
||||
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
||||
)
|
||||
)
|
||||
processed_masks.append(
|
||||
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||
)
|
||||
|
||||
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||
|
||||
text_embedding = torch.cat(text_embedding, dim=1)
|
||||
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
|
||||
|
||||
regions = None
|
||||
if not all_masks_are_none:
|
||||
regions = TextConditioningRegions(
|
||||
masks=torch.cat(processed_masks, dim=1),
|
||||
ranges=embedding_ranges,
|
||||
)
|
||||
|
||||
if is_sdxl:
|
||||
return SDXLConditioningInfo(
|
||||
embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
|
||||
), regions
|
||||
return BasicConditioningInfo(embeds=text_embedding), regions
|
||||
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
scheduler: Scheduler,
|
||||
unet: UNet2DConditionModel,
|
||||
seed: int,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
) -> TextConditioningData:
|
||||
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
|
||||
cond_list = self.positive_conditioning
|
||||
if not isinstance(cond_list, list):
|
||||
cond_list = [cond_list]
|
||||
uncond_list = self.negative_conditioning
|
||||
if not isinstance(uncond_list, list):
|
||||
uncond_list = [uncond_list]
|
||||
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
cond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
uncond_list, context, unet.device, unet.dtype
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
||||
scheduler,
|
||||
# for ddim scheduler
|
||||
eta=0.0, # ddim_eta
|
||||
# for ancestral and sde schedulers
|
||||
# flip all bits to have noise different from initial
|
||||
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
|
||||
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
|
||||
text_conditionings=cond_text_embeddings,
|
||||
masks=cond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
|
||||
text_conditionings=uncond_text_embeddings,
|
||||
masks=uncond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
|
||||
conditioning_data = TextConditioningData(
|
||||
uncond_text=uncond_text_embedding,
|
||||
cond_text=cond_text_embedding,
|
||||
uncond_regions=uncond_regions,
|
||||
cond_regions=cond_regions,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
@@ -488,8 +629,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
self,
|
||||
context: InvocationContext,
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
||||
conditioning_data: ConditioningData,
|
||||
exit_stack: ExitStack,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
) -> Optional[list[IPAdapterData]]:
|
||||
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
||||
to the `conditioning_data` (in-place).
|
||||
@@ -505,7 +648,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
return None
|
||||
|
||||
ip_adapter_data_list = []
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(single_ip_adapter.ip_adapter_model)
|
||||
@@ -528,16 +670,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
single_ipa_images, image_encoder_model
|
||||
)
|
||||
|
||||
conditioning_data.ip_adapter_conditioning.append(
|
||||
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
||||
)
|
||||
mask = single_ip_adapter.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.tensor_name)
|
||||
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||
|
||||
ip_adapter_data_list.append(
|
||||
IPAdapterData(
|
||||
ip_adapter_model=ip_adapter_model,
|
||||
weight=single_ip_adapter.weight,
|
||||
target_blocks=single_ip_adapter.target_blocks,
|
||||
begin_step_percent=single_ip_adapter.begin_step_percent,
|
||||
end_step_percent=single_ip_adapter.end_step_percent,
|
||||
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
|
||||
mask=mask,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -627,6 +773,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
steps: int,
|
||||
denoising_start: float,
|
||||
denoising_end: float,
|
||||
seed: int,
|
||||
) -> Tuple[int, List[int], int]:
|
||||
assert isinstance(scheduler, ConfigMixin)
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
@@ -655,7 +802,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||
num_inference_steps = len(timesteps) // scheduler.order
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
scheduler_step_kwargs = {}
|
||||
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||
if "generator" in scheduler_step_signature.parameters:
|
||||
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
||||
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
||||
# reproducibility.
|
||||
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
||||
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
@@ -749,7 +904,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
|
||||
)
|
||||
|
||||
controlnet_data = self.prep_control_data(
|
||||
context=context,
|
||||
@@ -763,16 +922,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ip_adapter_data = self.prep_ip_adapter_data(
|
||||
context=context,
|
||||
ip_adapter=self.ip_adapter,
|
||||
conditioning_data=conditioning_data,
|
||||
exit_stack=exit_stack,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
||||
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
result_latents = pipeline.latents_from_embeddings(
|
||||
@@ -785,6 +947,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
masked_latents=masked_latents,
|
||||
gradient_mask=gradient_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=controlnet_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
@@ -794,12 +957,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
result_latents = result_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -863,9 +1024,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae.disable_tiling()
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
with torch.inference_mode():
|
||||
# copied from diffusers pipeline
|
||||
@@ -877,9 +1036,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
image_dto = context.images.save(image=image)
|
||||
|
||||
@@ -918,9 +1075,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents.to(device),
|
||||
@@ -931,9 +1086,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
resized_latents = resized_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
@@ -960,8 +1114,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
@@ -973,9 +1126,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
resized_latents = resized_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
@@ -1107,8 +1258,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
def slerp(
|
||||
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
||||
@@ -1161,9 +1311,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
blended_latents = blended_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=blended_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||
|
||||
36
invokeai/app/invocations/mask.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext, invocation
|
||||
from invokeai.app.invocations.fields import InputField, TensorField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
|
||||
|
||||
@invocation(
|
||||
"rectangle_mask",
|
||||
title="Create Rectangle Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.1",
|
||||
)
|
||||
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Create a rectangular mask."""
|
||||
|
||||
width: int = InputField(description="The width of the entire mask.")
|
||||
height: int = InputField(description="The height of the entire mask.")
|
||||
x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
|
||||
y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
|
||||
rectangle_width: int = InputField(description="The width of the rectangular masked region.")
|
||||
rectangle_height: int = InputField(description="The height of the rectangular masked region.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
|
||||
mask[:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width] = (
|
||||
True
|
||||
)
|
||||
|
||||
mask_tensor_name = context.tensors.save(mask)
|
||||
return MaskOutput(
|
||||
mask=TensorField(tensor_name=mask_tensor_name),
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
)
|
||||
@@ -36,6 +36,7 @@ class IPAdapterMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
|
||||
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
|
||||
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
|
||||
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.util.devices import TorchDevice
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -46,7 +46,7 @@ def get_noise(
|
||||
height // downsampling_factor,
|
||||
width // downsampling_factor,
|
||||
],
|
||||
dtype=torch_dtype(device),
|
||||
dtype=TorchDevice.choose_torch_dtype(device=device),
|
||||
device=noise_device_type,
|
||||
generator=generator,
|
||||
).to("cpu")
|
||||
@@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):
|
||||
|
||||
@field_validator("seed", mode="before")
|
||||
def modulo_seed(cls, v):
|
||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
return v % (SEED_MAX + 1)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
noise = get_noise(
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
device=choose_torch_device(),
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
seed=self.seed,
|
||||
use_cpu=self.use_cpu,
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ from invokeai.app.invocations.fields import (
|
||||
InputField,
|
||||
LatentsField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
@@ -405,9 +406,19 @@ class ColorInvocation(BaseInvocation):
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Conditioning
|
||||
|
||||
|
||||
@invocation_output("mask_output")
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""A torch mask tensor."""
|
||||
|
||||
mask: TensorField = OutputField(description="The mask.")
|
||||
width: int = OutputField(description="The width of the mask in pixels.")
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@@ -14,7 +13,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
@@ -35,9 +34,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
|
||||
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
}
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
|
||||
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -120,9 +116,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
upscaled_image = upscaler.upscale(cv2_image)
|
||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
|
||||
|
||||
@@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
DEFAULT_CONVERT_CACHE = 20.0
|
||||
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
|
||||
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
||||
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
|
||||
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
|
||||
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
||||
CONFIG_SCHEMA_VERSION = "4.0.0"
|
||||
CONFIG_SCHEMA_VERSION = "4.0.1"
|
||||
|
||||
|
||||
def get_default_ram_cache_size() -> float:
|
||||
@@ -105,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
lazy_offload: Keep models in VRAM until their space is needed.
|
||||
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
|
||||
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
|
||||
@@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
@@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
return config
|
||||
|
||||
|
||||
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate v4.0.0 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, v in config_dict.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
else:
|
||||
parsed_config_dict[k] = v
|
||||
if k == "schema_version":
|
||||
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
|
||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
return config
|
||||
|
||||
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
@@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
||||
migrated_config.write_file(config_path)
|
||||
return migrated_config
|
||||
else:
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
||||
if loaded_config_dict["schema_version"] == "4.0.0":
|
||||
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
|
||||
loaded_config_dict.write_file(config_path)
|
||||
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
|
||||
@@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
@@ -42,7 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
@@ -634,11 +635,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._next_job_id += 1
|
||||
return id
|
||||
|
||||
@staticmethod
|
||||
def _guess_variant() -> Optional[ModelRepoVariant]:
|
||||
def _guess_variant(self) -> Optional[ModelRepoVariant]:
|
||||
"""Guess the best HuggingFace variant type to download."""
|
||||
precision = choose_precision(choose_torch_device())
|
||||
return ModelRepoVariant.FP16 if precision == "float16" else None
|
||||
precision = TorchDevice.choose_torch_dtype()
|
||||
return ModelRepoVariant.FP16 if precision == torch.float16 else None
|
||||
|
||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
return ModelInstallJob(
|
||||
@@ -754,6 +754,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._download_cache[download_job.source] = install_job # matches a download job to an install job
|
||||
install_job.download_parts.add(download_job)
|
||||
|
||||
# only start the jobs once install_job.download_parts is fully populated
|
||||
for download_job in install_job.download_parts:
|
||||
self._download_queue.submit_download_job(
|
||||
download_job,
|
||||
on_start=self._download_started_callback,
|
||||
@@ -762,6 +764,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
|
||||
return install_job
|
||||
|
||||
def _stat_size(self, path: Path) -> int:
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
@@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_device: torch.device = choose_torch_device(),
|
||||
execution_device: Optional[torch.device] = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
@@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
max_vram_cache_size=app_config.vram,
|
||||
lazy_offloading=app_config.lazy_offload,
|
||||
logger=logger,
|
||||
execution_device=execution_device,
|
||||
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||
)
|
||||
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
|
||||
loader = ModelLoadService(
|
||||
|
||||
@@ -86,6 +86,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now()
|
||||
elif event_name == "batch_enqueued":
|
||||
self._poll_now()
|
||||
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
|
||||
"completed",
|
||||
"failed",
|
||||
"canceled",
|
||||
]:
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
if not self._resume_event.is_set():
|
||||
|
||||
@@ -245,6 +245,18 @@ class ImagesInterface(InvocationContextInterface):
|
||||
"""
|
||||
return self._services.images.get_dto(image_name)
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
|
||||
"""Gets the internal path to an image or thumbnail.
|
||||
|
||||
Args:
|
||||
image_name: The name of the image to get the path of.
|
||||
thumbnail: Get the path of the thumbnail instead of the full image
|
||||
|
||||
Returns:
|
||||
The local path of the image or thumbnail.
|
||||
"""
|
||||
return self._services.images.get_path(image_name, thumbnail)
|
||||
|
||||
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
|
||||
@@ -13,7 +13,7 @@ from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
@@ -56,7 +56,7 @@ class DepthAnythingDetector:
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
||||
self.device = choose_torch_device()
|
||||
self.device = TorchDevice.choose_torch_device()
|
||||
|
||||
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
|
||||
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
|
||||
@@ -81,7 +81,7 @@ class DepthAnythingDetector:
|
||||
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
||||
self.model.eval()
|
||||
|
||||
self.model.to(choose_torch_device())
|
||||
self.model.to(self.device)
|
||||
return self.model
|
||||
|
||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||
@@ -94,7 +94,7 @@ class DepthAnythingDetector:
|
||||
|
||||
image_height, image_width = np_image.shape[:2]
|
||||
np_image = transform({"image": np_image})["image"]
|
||||
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
|
||||
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
depth = self.model(tensor_image)
|
||||
|
||||
@@ -7,7 +7,7 @@ import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .onnxdet import inference_detector
|
||||
from .onnxpose import inference_pose
|
||||
@@ -28,9 +28,9 @@ config = get_config()
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self):
|
||||
device = choose_torch_device()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
|
||||
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
|
||||
|
||||
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
|
||||
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
|
||||
|
||||
@@ -8,7 +8,7 @@ from PIL import Image
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
@@ -29,7 +29,7 @@ def load_jit_model(url_or_path, device):
|
||||
|
||||
class LaMA:
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = choose_torch_device()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||
|
||||
if not model_location.exists():
|
||||
|
||||
@@ -11,7 +11,7 @@ from cv2.typing import MatLike
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
"""
|
||||
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
|
||||
@@ -65,7 +65,7 @@ class RealESRGAN:
|
||||
self.pre_pad = pre_pad
|
||||
self.mod_scale: Optional[int] = None
|
||||
self.half = half
|
||||
self.device = choose_torch_device()
|
||||
self.device = TorchDevice.choose_torch_device()
|
||||
|
||||
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
||||
@@ -51,7 +51,7 @@ class SafetyChecker:
|
||||
cls._load_safety_checker()
|
||||
if cls.safety_checker is None or cls.feature_extractor is None:
|
||||
return False
|
||||
device = choose_torch_device()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
features = cls.feature_extractor([image], return_tensors="pt")
|
||||
features.to(device)
|
||||
cls.safety_checker.to(device)
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
# and modified as needed
|
||||
|
||||
# tencent-ailab comment:
|
||||
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
||||
|
||||
|
||||
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
|
||||
# loading.
|
||||
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
||||
def __init__(self):
|
||||
DiffusersAttnProcessor2_0.__init__(self)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
|
||||
ip_adapter_image_prompt_embeds parameter.
|
||||
"""
|
||||
return DiffusersAttnProcessor2_0.__call__(
|
||||
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
|
||||
)
|
||||
|
||||
|
||||
class IPAttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapater for PyTorch 2.0.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
assert len(weights) == len(scales)
|
||||
|
||||
self._weights = weights
|
||||
self._scales = scales
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Apply IP-Adapter attention.
|
||||
|
||||
Args:
|
||||
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
|
||||
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
|
||||
|
||||
for ipa_embed, ipa_weights, scale in zip(
|
||||
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
|
||||
):
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||
|
||||
ip_hidden_states = ipa_embed
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
@@ -1,53 +0,0 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
|
||||
|
||||
class UNetPatcher:
|
||||
"""A class that contains multiple IP-Adapters and can apply them to a UNet."""
|
||||
|
||||
def __init__(self, ip_adapters: list[IPAdapter]):
|
||||
self._ip_adapters = ip_adapters
|
||||
self._scales = [1.0] * len(self._ip_adapters)
|
||||
|
||||
def set_scale(self, idx: int, value: float):
|
||||
self._scales[idx] = value
|
||||
|
||||
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
||||
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
|
||||
weights into them.
|
||||
|
||||
Note that the `unet` param is only used to determine attention block dimensions and naming.
|
||||
"""
|
||||
# Construct a dict of attention processors based on the UNet's architecture.
|
||||
attn_procs = {}
|
||||
for idx, name in enumerate(unet.attn_processors.keys()):
|
||||
if name.endswith("attn1.processor"):
|
||||
attn_procs[name] = AttnProcessor2_0()
|
||||
else:
|
||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||
attn_procs[name] = IPAttnProcessor2_0(
|
||||
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
|
||||
self._scales,
|
||||
)
|
||||
return attn_procs
|
||||
|
||||
@contextmanager
|
||||
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
|
||||
"""A context manager that patches `unet` with IP-Adapter attention processors."""
|
||||
|
||||
attn_procs = self._prepare_attention_processors(unet)
|
||||
|
||||
orig_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
|
||||
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
|
||||
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
|
||||
unet.set_attn_processor(attn_procs)
|
||||
yield None
|
||||
finally:
|
||||
unet.set_attn_processor(orig_attn_processors)
|
||||
@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
# TO DO: The loader is not thread safe!
|
||||
@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self._logger = logger
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
|
||||
self._torch_dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
|
||||
@@ -30,15 +30,12 @@ import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
|
||||
from .model_locker import ModelLocker
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
@@ -244,9 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
@@ -271,7 +266,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
cache_entry.model.to(target_device)
|
||||
try:
|
||||
cache_entry.model.to(target_device)
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
@@ -389,8 +389,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
del self._cache_stack[pos]
|
||||
del self._cached_models[model_key]
|
||||
self._delete_cache_entry(cache_entry)
|
||||
del cache_entry
|
||||
|
||||
else:
|
||||
@@ -412,8 +411,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
|
||||
@@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from . import (
|
||||
AnyModelConfig,
|
||||
@@ -43,6 +43,7 @@ class ModelMerger(object):
|
||||
Initialize a ModelMerger object with the model installer.
|
||||
"""
|
||||
self._installer = installer
|
||||
self._dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
def merge_diffusion_models(
|
||||
self,
|
||||
@@ -68,7 +69,7 @@ class ModelMerger(object):
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
||||
dtype = torch.float16 if variant == "fp16" else self._dtype
|
||||
|
||||
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
|
||||
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
|
||||
@@ -151,7 +152,7 @@ class ModelMerger(object):
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
||||
dtype = torch.float16 if variant == "fp16" else self._dtype
|
||||
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
|
||||
|
||||
# register model and get its unique key
|
||||
|
||||
@@ -21,12 +21,11 @@ from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import normalize_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -149,16 +148,6 @@ class ControlNetData:
|
||||
resize_mode: str = Field(default="just_resize")
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterData:
|
||||
ip_adapter_model: IPAdapter = Field(default=None)
|
||||
# TODO: change to polymorphic so can do different weights per step (once implemented...)
|
||||
weight: Union[float, List[float]] = Field(default=1.0)
|
||||
# weight: float = Field(default=1.0)
|
||||
begin_step_percent: float = Field(default=0.0)
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class T2IAdapterData:
|
||||
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
|
||||
@@ -266,7 +255,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.unet.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
|
||||
mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {self.unet.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
@@ -295,7 +284,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
conditioning_data: TextConditioningData,
|
||||
*,
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps: torch.Tensor,
|
||||
@@ -308,7 +298,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
gradient_mask: Optional[bool] = False,
|
||||
seed: Optional[int] = None,
|
||||
seed: int,
|
||||
) -> torch.Tensor:
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents
|
||||
@@ -326,20 +316,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
if mask is not None:
|
||||
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
latents = torch.lerp(
|
||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||
)
|
||||
|
||||
if is_inpainting_model(self.unet):
|
||||
if masked_latents is None:
|
||||
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
||||
@@ -348,6 +324,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self._unet_forward, mask, masked_latents
|
||||
)
|
||||
else:
|
||||
# if no noise provided, noisify unmasked area based on seed
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||
|
||||
try:
|
||||
@@ -355,6 +340,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
@@ -380,7 +366,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
timesteps,
|
||||
conditioning_data: ConditioningData,
|
||||
conditioning_data: TextConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
@@ -397,22 +384,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
ip_adapter_unet_patcher = None
|
||||
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
use_ip_adapter = ip_adapter_data is not None
|
||||
use_regional_prompting = (
|
||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||
)
|
||||
unet_attention_patcher = None
|
||||
self.use_ip_adapter = use_ip_adapter
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
ip_adapters: Optional[List[UNetIPAdapterData]] = (
|
||||
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
|
||||
if use_ip_adapter
|
||||
else None
|
||||
)
|
||||
self.use_ip_adapter = False
|
||||
elif ip_adapter_data is not None:
|
||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
||||
# As it is now, the IP-Adapter will silently be skipped.
|
||||
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
||||
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
self.use_ip_adapter = True
|
||||
else:
|
||||
attn_ctx = nullcontext()
|
||||
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
|
||||
with attn_ctx:
|
||||
if callback is not None:
|
||||
@@ -435,11 +422,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
|
||||
)
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||
@@ -463,14 +450,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
latents: torch.Tensor,
|
||||
conditioning_data: ConditioningData,
|
||||
conditioning_data: TextConditioningData,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
@@ -485,23 +472,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||
|
||||
# handle IP-Adapter
|
||||
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
|
||||
for i, single_ip_adapter_data in enumerate(ip_adapter_data):
|
||||
first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
|
||||
last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
|
||||
weight = (
|
||||
single_ip_adapter_data.weight[step_index]
|
||||
if isinstance(single_ip_adapter_data.weight, List)
|
||||
else single_ip_adapter_data.weight
|
||||
)
|
||||
if step_index >= first_adapter_step and step_index <= last_adapter_step:
|
||||
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
|
||||
ip_adapter_unet_patcher.set_scale(i, weight)
|
||||
else:
|
||||
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
|
||||
ip_adapter_unet_patcher.set_scale(i, 0.0)
|
||||
|
||||
# Handle ControlNet(s)
|
||||
down_block_additional_residuals = None
|
||||
mid_block_additional_residual = None
|
||||
@@ -550,6 +520,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
conditioning_data=conditioning_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
|
||||
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
|
||||
@@ -569,7 +540,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||
|
||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
|
||||
for guidance in additional_guidance:
|
||||
|
||||
@@ -1,27 +1,17 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Union
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .cross_attention_control import Arguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraConditioningInfo:
|
||||
tokens_count_including_eos_bos: int
|
||||
cross_attention_control_args: Optional[Arguments] = None
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.cross_attention_control_args is not None
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
"""SD 1/2 text conditioning information produced by Compel."""
|
||||
|
||||
embeds: torch.Tensor
|
||||
extra_conditioning: Optional[ExtraConditioningInfo]
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||
@@ -35,6 +25,8 @@ class ConditioningFieldData:
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
"""SDXL text conditioning information produced by Compel."""
|
||||
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
@@ -57,37 +49,75 @@ class IPAdapterConditioningInfo:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: BasicConditioningInfo
|
||||
text_embeddings: BasicConditioningInfo
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
guidance_scale: Union[float, List[float]]
|
||||
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
|
||||
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
||||
"""
|
||||
guidance_rescale_multiplier: float = 0
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
class IPAdapterData:
|
||||
ip_adapter_model: IPAdapter
|
||||
ip_adapter_conditioning: IPAdapterConditioningInfo
|
||||
mask: torch.Tensor
|
||||
target_blocks: List[str]
|
||||
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
||||
# Either a single weight applied to all steps, or a list of weights for each step.
|
||||
weight: Union[float, List[float]] = 1.0
|
||||
begin_step_percent: float = 0.0
|
||||
end_step_percent: float = 1.0
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.text_embeddings.dtype
|
||||
def scale_for_step(self, step_index: int, total_steps: int) -> float:
|
||||
first_adapter_step = math.floor(self.begin_step_percent * total_steps)
|
||||
last_adapter_step = math.ceil(self.end_step_percent * total_steps)
|
||||
weight = self.weight[step_index] if isinstance(self.weight, List) else self.weight
|
||||
if step_index >= first_adapter_step and step_index <= last_adapter_step:
|
||||
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
|
||||
return weight
|
||||
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
|
||||
return 0.0
|
||||
|
||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||
scheduler_args = dict(self.scheduler_args)
|
||||
step_method = inspect.signature(scheduler.step)
|
||||
for name, value in kwargs.items():
|
||||
try:
|
||||
step_method.bind_partial(**{name: value})
|
||||
except TypeError:
|
||||
# FIXME: don't silently discard arguments
|
||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||
else:
|
||||
scheduler_args[name] = value
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
||||
|
||||
@dataclass
|
||||
class Range:
|
||||
start: int
|
||||
end: int
|
||||
|
||||
|
||||
class TextConditioningRegions:
|
||||
def __init__(
|
||||
self,
|
||||
masks: torch.Tensor,
|
||||
ranges: list[Range],
|
||||
):
|
||||
# A binary mask indicating the regions of the image that the prompt should be applied to.
|
||||
# Shape: (1, num_prompts, height, width)
|
||||
# Dtype: torch.bool
|
||||
self.masks = masks
|
||||
|
||||
# A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to.
|
||||
# ranges[i] contains the embedding range for the i'th prompt / mask.
|
||||
self.ranges = ranges
|
||||
|
||||
assert self.masks.shape[1] == len(self.ranges)
|
||||
|
||||
|
||||
class TextConditioningData:
|
||||
def __init__(
|
||||
self,
|
||||
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
uncond_regions: Optional[TextConditioningRegions],
|
||||
cond_regions: Optional[TextConditioningRegions],
|
||||
guidance_scale: Union[float, List[float]],
|
||||
guidance_rescale_multiplier: float = 0,
|
||||
):
|
||||
self.uncond_text = uncond_text
|
||||
self.cond_text = cond_text
|
||||
self.uncond_regions = uncond_regions
|
||||
self.cond_regions = cond_regions
|
||||
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
self.guidance_scale = guidance_scale
|
||||
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
|
||||
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
self.guidance_rescale_multiplier = guidance_rescale_multiplier
|
||||
|
||||
def is_sdxl(self):
|
||||
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
|
||||
|
||||
class CrossAttnControlContext:
|
||||
def __init__(self, arguments: Arguments):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
"""
|
||||
self.cross_attention_mask: Optional[torch.Tensor] = None
|
||||
self.cross_attention_index_map: Optional[torch.Tensor] = None
|
||||
self.arguments = arguments
|
||||
|
||||
def get_active_cross_attention_control_types_for_step(
|
||||
self, percent_through: float = None
|
||||
) -> list[CrossAttentionType]:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||
"""
|
||||
if percent_through is None:
|
||||
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
|
||||
|
||||
opts = self.arguments.edit_options
|
||||
to_control = []
|
||||
if opts["s_start"] <= percent_through < opts["s_end"]:
|
||||
to_control.append(CrossAttentionType.SELF)
|
||||
if opts["t_start"] <= percent_through < opts["t_end"]:
|
||||
to_control.append(CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
:param model: The unet model to inject into.
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# adapted from init_attention_edit
|
||||
device = context.arguments.edited_conditioning.device
|
||||
|
||||
# urgh. should this be hardcoded?
|
||||
max_length = 77
|
||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||
mask = torch.zeros(max_length, dtype=torch_dtype(device))
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
context.cross_attention_mask = mask.to(device)
|
||||
context.cross_attention_index_map = indices.to(device)
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next(
|
||||
(p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
|
||||
)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SwapCrossAttnContext:
|
||||
modified_text_embeddings: torch.Tensor
|
||||
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
|
||||
mask: torch.Tensor # in the target space of the index_map
|
||||
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
|
||||
|
||||
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
|
||||
return attn_type in self.cross_attention_types_to_do
|
||||
|
||||
@classmethod
|
||||
def make_mask_and_index_map(
|
||||
cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention
|
||||
mask = torch.zeros(max_length)
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":
|
||||
# these tokens remain the same as in the original prompt
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
return mask, indices
|
||||
|
||||
|
||||
class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
# TODO: dynamically pick slice size based on memory conditions
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
# kwargs
|
||||
swap_cross_attn_context: SwapCrossAttnContext = None,
|
||||
**kwargs,
|
||||
):
|
||||
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
|
||||
|
||||
# if cross-attention control is not in play, just call through to the base implementation.
|
||||
if (
|
||||
attention_type is CrossAttentionType.SELF
|
||||
or swap_cross_attn_context is None
|
||||
or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
|
||||
):
|
||||
# print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
|
||||
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
# else:
|
||||
# print(f"SwapCrossAttnContext for {attention_type} active")
|
||||
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask=attention_mask,
|
||||
target_length=sequence_length,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
original_text_embeddings = encoder_hidden_states
|
||||
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
|
||||
original_text_key = attn.to_k(original_text_embeddings)
|
||||
modified_text_key = attn.to_k(modified_text_embeddings)
|
||||
original_value = attn.to_v(original_text_embeddings)
|
||||
modified_value = attn.to_v(modified_text_embeddings)
|
||||
|
||||
original_text_key = attn.head_to_batch_dim(original_text_key)
|
||||
modified_text_key = attn.head_to_batch_dim(modified_text_key)
|
||||
original_value = attn.head_to_batch_dim(original_value)
|
||||
modified_value = attn.head_to_batch_dim(modified_value)
|
||||
|
||||
# compute slices and prepare output tensor
|
||||
batch_size_attention = query.shape[0]
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, sequence_length, dim // attn.heads),
|
||||
device=query.device,
|
||||
dtype=query.dtype,
|
||||
)
|
||||
|
||||
# do slices
|
||||
for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
original_key_slice = original_text_key[start_idx:end_idx]
|
||||
modified_key_slice = modified_text_key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
|
||||
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
|
||||
|
||||
# because the prompt modifications may result in token sequences shifted forwards or backwards,
|
||||
# the original attention probabilities must be remapped to account for token index changes in the
|
||||
# modified prompt
|
||||
remapped_original_attn_slice = torch.index_select(
|
||||
original_attn_slice, -1, swap_cross_attn_context.index_map
|
||||
)
|
||||
|
||||
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
|
||||
mask = swap_cross_attn_context.mask
|
||||
inverse_mask = 1 - mask
|
||||
attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
|
||||
|
||||
del remapped_original_attn_slice, modified_attn_slice
|
||||
|
||||
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
|
||||
# done
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
|
||||
def __init__(self):
|
||||
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice
|
||||
214
invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, cast
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterAttentionWeights:
|
||||
ip_adapter_weights: IPAttentionProcessorWeights
|
||||
skip: bool
|
||||
|
||||
|
||||
class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
|
||||
This implementation is based on
|
||||
https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
|
||||
Supported custom features:
|
||||
- IP-Adapter
|
||||
- Regional prompt attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
|
||||
):
|
||||
"""Initialize a CustomAttnProcessor2_0.
|
||||
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
|
||||
layer-specific are passed to __init__().
|
||||
Args:
|
||||
ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
|
||||
for the i'th IP-Adapter.
|
||||
"""
|
||||
super().__init__()
|
||||
self._ip_adapter_attention_weights = ip_adapter_attention_weights
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
# For Regional Prompting:
|
||||
regional_prompt_data: Optional[RegionalPromptData] = None,
|
||||
percent_through: Optional[torch.Tensor] = None,
|
||||
# For IP-Adapter:
|
||||
regional_ip_data: Optional[RegionalIPData] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""Apply attention.
|
||||
Args:
|
||||
regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
|
||||
apply regional prompt masking.
|
||||
regional_ip_data: The IP-Adapter data for the current batch.
|
||||
"""
|
||||
# If true, we are doing cross-attention, if false we are doing self-attention.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
_, query_seq_len, _ = hidden_states.shape
|
||||
# Handle regional prompt attention masks.
|
||||
if regional_prompt_data is not None and is_cross_attention:
|
||||
assert percent_through is not None
|
||||
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
|
||||
query_seq_len=query_seq_len, key_seq_len=sequence_length
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = prompt_region_attention_mask
|
||||
else:
|
||||
attention_mask = prompt_region_attention_mask + attention_mask
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
if is_cross_attention:
|
||||
if self._ip_adapter_attention_weights:
|
||||
assert regional_ip_data is not None
|
||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||
|
||||
assert (
|
||||
len(regional_ip_data.image_prompt_embeds)
|
||||
== len(self._ip_adapter_attention_weights)
|
||||
== len(regional_ip_data.scales)
|
||||
== ip_masks.shape[1]
|
||||
)
|
||||
|
||||
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
||||
ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
|
||||
ipa_scale = regional_ip_data.scales[ipa_index]
|
||||
ip_mask = ip_masks[0, ipa_index, ...]
|
||||
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||
|
||||
ip_hidden_states = ipa_embed
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
if not self._ip_adapter_attention_weights[ipa_index].skip:
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||
else:
|
||||
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
|
||||
assert regional_ip_data is None
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End of unmodified block from AttnProcessor2_0
|
||||
|
||||
# casting torch.Tensor to torch.FloatTensor to avoid type issues
|
||||
return cast(torch.FloatTensor, hidden_states)
|
||||
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
|
||||
|
||||
class RegionalIPData:
|
||||
"""A class to manage the data for regional IP-Adapter conditioning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_prompt_embeds: list[torch.Tensor],
|
||||
scales: list[float],
|
||||
masks: list[torch.Tensor],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
max_downscale_factor: int = 8,
|
||||
):
|
||||
"""Initialize a `IPAdapterConditioningData` object."""
|
||||
assert len(image_prompt_embeds) == len(scales) == len(masks)
|
||||
|
||||
# The image prompt embeddings.
|
||||
# regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
|
||||
# has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||
self.image_prompt_embeds = image_prompt_embeds
|
||||
|
||||
# The scales for the IP-Adapter attention.
|
||||
# scales[i] contains the attention scale for the i'th IP-Adapter.
|
||||
self.scales = scales
|
||||
|
||||
# The IP-Adapter masks.
|
||||
# self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
|
||||
# s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
|
||||
# regions and 0.0 for excluded regions.
|
||||
self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)
|
||||
|
||||
def _prepare_masks(
|
||||
self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
|
||||
) -> dict[int, torch.Tensor]:
|
||||
"""Prepare the masks for the IP-Adapter attention."""
|
||||
# Concatenate the masks so that they can be processed more efficiently.
|
||||
mask_tensor = torch.cat(masks, dim=1)
|
||||
|
||||
mask_tensor = mask_tensor.to(device=device, dtype=dtype)
|
||||
|
||||
masks_by_seq_len: dict[int, torch.Tensor] = {}
|
||||
|
||||
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||
downscale_factor = 1
|
||||
while downscale_factor <= max_downscale_factor:
|
||||
b, num_ip_adapters, h, w = mask_tensor.shape
|
||||
# Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
|
||||
assert b == 1
|
||||
|
||||
# The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
|
||||
# the spatial features.
|
||||
query_seq_len = h * w
|
||||
|
||||
masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))
|
||||
|
||||
downscale_factor *= 2
|
||||
if downscale_factor <= max_downscale_factor:
|
||||
# We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
|
||||
# regions to be lost entirely.
|
||||
#
|
||||
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
|
||||
#
|
||||
# TODO(ryand): In the future, we may want to experiment with other downsampling methods.
|
||||
mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2, ceil_mode=True)
|
||||
|
||||
return masks_by_seq_len
|
||||
|
||||
def get_masks(self, query_seq_len: int) -> torch.Tensor:
|
||||
"""Get the mask for the given query sequence length."""
|
||||
return self._masks_by_seq_len[query_seq_len]
|
||||
@@ -0,0 +1,105 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
TextConditioningRegions,
|
||||
)
|
||||
|
||||
|
||||
class RegionalPromptData:
|
||||
"""A class to manage the prompt data for regional conditioning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
regions: list[TextConditioningRegions],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
max_downscale_factor: int = 8,
|
||||
):
|
||||
"""Initialize a `RegionalPromptData` object.
|
||||
Args:
|
||||
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
|
||||
batch.
|
||||
device (torch.device): The device to use for the attention masks.
|
||||
dtype (torch.dtype): The data type to use for the attention masks.
|
||||
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
|
||||
in steps of 2x.
|
||||
"""
|
||||
self._regions = regions
|
||||
self._device = device
|
||||
self._dtype = dtype
|
||||
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
|
||||
# sequence length of s.
|
||||
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
|
||||
regions, max_downscale_factor
|
||||
)
|
||||
self._negative_cross_attn_mask_score = -10000.0
|
||||
|
||||
def _prepare_spatial_masks(
|
||||
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
||||
) -> list[dict[int, torch.Tensor]]:
|
||||
"""Prepare the spatial masks for all downscaling factors."""
|
||||
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
|
||||
# of s.
|
||||
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
|
||||
|
||||
for batch_sample_regions in regions:
|
||||
batch_sample_masks_by_seq_len.append({})
|
||||
|
||||
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
|
||||
|
||||
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||
downscale_factor = 1
|
||||
while downscale_factor <= max_downscale_factor:
|
||||
b, _num_prompts, h, w = batch_sample_masks.shape
|
||||
assert b == 1
|
||||
query_seq_len = h * w
|
||||
|
||||
batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
|
||||
|
||||
downscale_factor *= 2
|
||||
if downscale_factor <= max_downscale_factor:
|
||||
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
|
||||
# regions to be lost entirely.
|
||||
#
|
||||
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
|
||||
#
|
||||
# TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g.
|
||||
# nearest interpolation), and could potentially use a weighted mask rather than a binary mask.
|
||||
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2, ceil_mode=True)
|
||||
|
||||
return batch_sample_masks_by_seq_len
|
||||
|
||||
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
|
||||
"""Get the cross-attention mask for the given query sequence length.
|
||||
Args:
|
||||
query_seq_len: The length of the flattened spatial features at the current downscaling level.
|
||||
key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
|
||||
layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
|
||||
Returns:
|
||||
torch.Tensor: The cross-attention score mask.
|
||||
shape: (batch_size, query_seq_len, key_seq_len).
|
||||
dtype: float
|
||||
"""
|
||||
batch_size = len(self._spatial_masks_by_seq_len)
|
||||
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
|
||||
|
||||
# Create an empty attention mask with the correct shape.
|
||||
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
|
||||
batch_sample_regions = self._regions[batch_idx]
|
||||
|
||||
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
|
||||
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
|
||||
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
|
||||
|
||||
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
|
||||
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
|
||||
batch_sample_query_mask = batch_sample_query_scores > 0.5
|
||||
batch_sample_query_scores[batch_sample_query_mask] = 0.0
|
||||
batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
|
||||
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
|
||||
|
||||
return attn_mask
|
||||
@@ -1,26 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningData,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
CrossAttnControlContext,
|
||||
SwapCrossAttnContext,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
IPAdapterData,
|
||||
Range,
|
||||
TextConditioningData,
|
||||
TextConditioningRegions,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
ModelForwardCallback: TypeAlias = Union[
|
||||
# x, t, conditioning, Optional[cross-attention kwargs]
|
||||
@@ -58,31 +52,8 @@ class InvokeAIDiffuserComponent:
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = config.sequential_guidance
|
||||
|
||||
@contextmanager
|
||||
def custom_attention_context(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
):
|
||||
old_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
self.cross_attention_control_context = CrossAttnControlContext(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
self.cross_attention_control_context,
|
||||
)
|
||||
|
||||
yield None
|
||||
finally:
|
||||
self.cross_attention_control_context = None
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
|
||||
def do_controlnet_step(
|
||||
self,
|
||||
control_data,
|
||||
@@ -90,7 +61,7 @@ class InvokeAIDiffuserComponent:
|
||||
timestep: torch.Tensor,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
conditioning_data,
|
||||
conditioning_data: TextConditioningData,
|
||||
):
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
|
||||
@@ -123,28 +94,28 @@ class InvokeAIDiffuserComponent:
|
||||
added_cond_kwargs = None
|
||||
|
||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||
}
|
||||
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
||||
encoder_hidden_states = conditioning_data.cond_text.embeds
|
||||
encoder_attention_mask = None
|
||||
else:
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
conditioning_data.text_embeddings.pooled_embeds,
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
conditioning_data.text_embeddings.add_time_ids,
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
@@ -153,8 +124,8 @@ class InvokeAIDiffuserComponent:
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
conditioning_data.uncond_text.embeds,
|
||||
conditioning_data.cond_text.embeds,
|
||||
)
|
||||
if isinstance(control_datum.weight, list):
|
||||
# if controlnet has multiple weights, use the weight for the current step
|
||||
@@ -198,24 +169,15 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
conditioning_data: ConditioningData,
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]],
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
):
|
||||
cross_attention_control_types_to_do = []
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = (
|
||||
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||
)
|
||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||
|
||||
if wants_cross_attention_control or self.sequential_guidance:
|
||||
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
|
||||
# control is currently only supported in sequential mode.
|
||||
if self.sequential_guidance:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
@@ -223,7 +185,9 @@ class InvokeAIDiffuserComponent:
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
@@ -236,6 +200,9 @@ class InvokeAIDiffuserComponent:
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
@@ -294,53 +261,84 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
def _apply_standard_conditioning(
|
||||
self,
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data: ConditioningData,
|
||||
x: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]],
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
||||
the cost of higher memory usage.
|
||||
"""
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
cross_attention_kwargs = {}
|
||||
if ip_adapter_data is not None:
|
||||
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
|
||||
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.stack(
|
||||
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
|
||||
)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
image_prompt_embeds = [
|
||||
torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(
|
||||
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||
)
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
added_cond_kwargs = None
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
conditioning_data.text_embeddings.pooled_embeds,
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
conditioning_data.text_embeddings.add_time_ids,
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
|
||||
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
||||
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
|
||||
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
|
||||
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
|
||||
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
|
||||
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
|
||||
regions = []
|
||||
for c, r in [
|
||||
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
|
||||
(conditioning_data.cond_text, conditioning_data.cond_regions),
|
||||
]:
|
||||
if r is None:
|
||||
# Create a dummy mask and range for text conditioning that doesn't have region masks.
|
||||
_, _, h, w = x.shape
|
||||
r = TextConditioningRegions(
|
||||
masks=torch.ones((1, 1, h, w), dtype=x.dtype),
|
||||
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
||||
)
|
||||
regions.append(r)
|
||||
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||
regions=regions, device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = step_index / total_step_count
|
||||
|
||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
|
||||
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
|
||||
)
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice,
|
||||
@@ -360,8 +358,10 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
conditioning_data: ConditioningData,
|
||||
cross_attention_control_types_to_do: list[CrossAttentionType],
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]],
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
@@ -391,53 +391,48 @@ class InvokeAIDiffuserComponent:
|
||||
if mid_block_additional_residual is not None:
|
||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||
|
||||
# If cross-attention control is enabled, prepare the SwapCrossAttnContext.
|
||||
cross_attn_processor_context = None
|
||||
if self.cross_attention_control_context is not None:
|
||||
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
|
||||
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
|
||||
# will be populated before the conditioned pass.
|
||||
cross_attn_processor_context = SwapCrossAttnContext(
|
||||
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
|
||||
index_map=self.cross_attention_control_context.cross_attention_index_map,
|
||||
mask=self.cross_attention_control_context.cross_attention_mask,
|
||||
cross_attention_types_to_do=[],
|
||||
)
|
||||
|
||||
#####################
|
||||
# Unconditioned pass
|
||||
#####################
|
||||
|
||||
cross_attention_kwargs = None
|
||||
cross_attention_kwargs = {}
|
||||
|
||||
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
if ip_adapter_data is not None:
|
||||
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
|
||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
image_prompt_embeds = [
|
||||
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
|
||||
# Prepare cross-attention control kwargs for the unconditioned pass.
|
||||
if cross_attn_processor_context is not None:
|
||||
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
|
||||
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(
|
||||
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||
)
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
||||
added_cond_kwargs = None
|
||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||
if is_sdxl:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
"text_embeds": conditioning_data.uncond_text.pooled_embeds,
|
||||
"time_ids": conditioning_data.uncond_text.add_time_ids,
|
||||
}
|
||||
|
||||
# Prepare prompt regions for the unconditioned pass.
|
||||
if conditioning_data.uncond_regions is not None:
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = step_index / total_step_count
|
||||
|
||||
# Run unconditioned UNet denoising (i.e. negative prompt).
|
||||
unconditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
conditioning_data.uncond_text.embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=uncond_down_block,
|
||||
mid_block_additional_residual=uncond_mid_block,
|
||||
@@ -449,36 +444,43 @@ class InvokeAIDiffuserComponent:
|
||||
# Conditioned pass
|
||||
###################
|
||||
|
||||
cross_attention_kwargs = None
|
||||
cross_attention_kwargs = {}
|
||||
|
||||
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
if ip_adapter_data is not None:
|
||||
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
|
||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
image_prompt_embeds = [
|
||||
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
|
||||
# Prepare cross-attention control kwargs for the conditioned pass.
|
||||
if cross_attn_processor_context is not None:
|
||||
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
||||
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
|
||||
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(
|
||||
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||
)
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
# Prepare SDXL conditioning kwargs for the conditioned pass.
|
||||
added_cond_kwargs = None
|
||||
if is_sdxl:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||
}
|
||||
|
||||
# Prepare prompt regions for the conditioned pass.
|
||||
if conditioning_data.cond_regions is not None:
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = step_index / total_step_count
|
||||
|
||||
# Run conditioned UNet denoising (i.e. positive prompt).
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
conditioning_data.cond_text.embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=cond_down_block,
|
||||
mid_block_additional_residual=cond_mid_block,
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, TypedDict
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
|
||||
CustomAttnProcessor2_0,
|
||||
IPAdapterAttentionWeights,
|
||||
)
|
||||
|
||||
|
||||
class UNetIPAdapterData(TypedDict):
|
||||
ip_adapter: IPAdapter
|
||||
target_blocks: List[str]
|
||||
|
||||
|
||||
class UNetAttentionPatcher:
|
||||
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
|
||||
|
||||
def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
|
||||
self._ip_adapters = ip_adapter_data
|
||||
|
||||
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
||||
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
|
||||
weights into them (if IP-Adapters are being applied).
|
||||
Note that the `unet` param is only used to determine attention block dimensions and naming.
|
||||
"""
|
||||
# Construct a dict of attention processors based on the UNet's architecture.
|
||||
attn_procs = {}
|
||||
for idx, name in enumerate(unet.attn_processors.keys()):
|
||||
if name.endswith("attn1.processor") or self._ip_adapters is None:
|
||||
# "attn1" processors do not use IP-Adapters.
|
||||
attn_procs[name] = CustomAttnProcessor2_0()
|
||||
else:
|
||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
|
||||
|
||||
for ip_adapter in self._ip_adapters:
|
||||
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
|
||||
skip = True
|
||||
for block in ip_adapter["target_blocks"]:
|
||||
if block in name:
|
||||
skip = False
|
||||
break
|
||||
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
|
||||
ip_adapter_weights=ip_adapter_weights, skip=skip
|
||||
)
|
||||
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
|
||||
|
||||
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
|
||||
|
||||
return attn_procs
|
||||
|
||||
@contextmanager
|
||||
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
|
||||
"""A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
|
||||
attn_procs = self._prepare_attention_processors(unet)
|
||||
orig_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
|
||||
# the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
|
||||
# moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
|
||||
unet.set_attn_processor(attn_procs)
|
||||
yield None
|
||||
finally:
|
||||
unet.set_attn_processor(orig_attn_processors)
|
||||
@@ -2,7 +2,6 @@
|
||||
Initialization file for invokeai.backend.util
|
||||
"""
|
||||
|
||||
from .devices import choose_precision, choose_torch_device
|
||||
from .logging import InvokeAILogger
|
||||
from .util import GIG, Chdir, directory_size
|
||||
|
||||
@@ -11,6 +10,4 @@ __all__ = [
|
||||
"directory_size",
|
||||
"Chdir",
|
||||
"InvokeAILogger",
|
||||
"choose_precision",
|
||||
"choose_torch_device",
|
||||
]
|
||||
|
||||
@@ -1,91 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import autocast
|
||||
from deprecated import deprecated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
# legacy APIs
|
||||
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
CUDA_DEVICE = torch.device("cuda")
|
||||
MPS_DEVICE = torch.device("mps")
|
||||
|
||||
|
||||
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
|
||||
def choose_precision(device: torch.device) -> TorchPrecisionNames:
|
||||
"""Return the string representation of the recommended torch device."""
|
||||
torch_dtype = TorchDevice.choose_torch_dtype(device)
|
||||
return PRECISION_TO_NAME[torch_dtype]
|
||||
|
||||
|
||||
@deprecated("Use TorchDevice.choose_torch_device() instead.") # type: ignore
|
||||
def choose_torch_device() -> torch.device:
|
||||
"""Convenience routine for guessing which GPU device to run model on"""
|
||||
config = get_config()
|
||||
if config.device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
"""Return the torch.device to use for accelerated inference."""
|
||||
return TorchDevice.choose_torch_device()
|
||||
|
||||
|
||||
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
|
||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||
"""Return the torch precision for the recommended torch device."""
|
||||
return TorchDevice.choose_torch_dtype(device)
|
||||
|
||||
|
||||
NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
|
||||
|
||||
|
||||
class TorchDevice:
|
||||
"""Abstraction layer for torch devices."""
|
||||
|
||||
@classmethod
|
||||
def choose_torch_device(cls) -> torch.device:
|
||||
"""Return the torch.device to use for accelerated inference."""
|
||||
app_config = get_config()
|
||||
if app_config.device != "auto":
|
||||
device = torch.device(app_config.device)
|
||||
elif torch.cuda.is_available():
|
||||
device = CUDA_DEVICE
|
||||
elif torch.backends.mps.is_available():
|
||||
device = MPS_DEVICE
|
||||
else:
|
||||
return CPU_DEVICE
|
||||
else:
|
||||
return torch.device(config.device)
|
||||
device = CPU_DEVICE
|
||||
return cls.normalize(device)
|
||||
|
||||
|
||||
def get_torch_device_name() -> str:
|
||||
device = choose_torch_device()
|
||||
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
|
||||
|
||||
|
||||
# We are in transition here from using a single global AppConfig to allowing multiple
|
||||
# configurations. It is strongly recommended to pass the app_config to this function.
|
||||
def choose_precision(
|
||||
device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
|
||||
) -> Literal["float32", "float16", "bfloat16"]:
|
||||
"""Return an appropriate precision for the given torch device."""
|
||||
app_config = app_config or get_config()
|
||||
if device.type == "cuda":
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||
if app_config.precision == "float32":
|
||||
return "float32"
|
||||
elif app_config.precision == "bfloat16":
|
||||
return "bfloat16"
|
||||
@classmethod
|
||||
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
|
||||
"""Return the precision to use for accelerated inference."""
|
||||
device = device or cls.choose_torch_device()
|
||||
config = get_config()
|
||||
if device.type == "cuda" and torch.cuda.is_available():
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
|
||||
# These GPUs have limited support for float16
|
||||
return cls._to_dtype("float32")
|
||||
elif config.precision == "auto":
|
||||
# Default to float16 for CUDA devices
|
||||
return cls._to_dtype("float16")
|
||||
else:
|
||||
return "float16"
|
||||
elif device.type == "mps":
|
||||
return "float16"
|
||||
return "float32"
|
||||
# Use the user-defined precision
|
||||
return cls._to_dtype(config.precision)
|
||||
|
||||
elif device.type == "mps" and torch.backends.mps.is_available():
|
||||
if config.precision == "auto":
|
||||
# Default to float16 for MPS devices
|
||||
return cls._to_dtype("float16")
|
||||
else:
|
||||
# Use the user-defined precision
|
||||
return cls._to_dtype(config.precision)
|
||||
# CPU / safe fallback
|
||||
return cls._to_dtype("float32")
|
||||
|
||||
# We are in transition here from using a single global AppConfig to allowing multiple
|
||||
# configurations. It is strongly recommended to pass the app_config to this function.
|
||||
def torch_dtype(
|
||||
device: Optional[torch.device] = None,
|
||||
app_config: Optional[InvokeAIAppConfig] = None,
|
||||
) -> torch.dtype:
|
||||
device = device or choose_torch_device()
|
||||
precision = choose_precision(device, app_config)
|
||||
if precision == "float16":
|
||||
return torch.float16
|
||||
if precision == "bfloat16":
|
||||
return torch.bfloat16
|
||||
else:
|
||||
# "auto", "autocast", "float32"
|
||||
return torch.float32
|
||||
@classmethod
|
||||
def get_torch_device_name(cls) -> str:
|
||||
"""Return the device name for the current torch device."""
|
||||
device = cls.choose_torch_device()
|
||||
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
|
||||
|
||||
|
||||
def choose_autocast(precision):
|
||||
"""Returns an autocast context or nullcontext for the given precision string"""
|
||||
# float16 currently requires autocast to avoid errors like:
|
||||
# 'expected scalar type Half but found Float'
|
||||
if precision == "autocast" or precision == "float16":
|
||||
return autocast
|
||||
return nullcontext
|
||||
|
||||
|
||||
def normalize_device(device: Union[str, torch.device]) -> torch.device:
|
||||
"""Ensure device has a device index defined, if appropriate."""
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
# cuda might be the only torch backend that currently uses the device index?
|
||||
# I don't see anything like `current_device` for cpu or mps.
|
||||
if device.type == "cuda":
|
||||
@classmethod
|
||||
def normalize(cls, device: Union[str, torch.device]) -> torch.device:
|
||||
"""Add the device index to CUDA devices."""
|
||||
device = torch.device(device)
|
||||
if device.index is None and device.type == "cuda" and torch.cuda.is_available():
|
||||
device = torch.device(device.type, torch.cuda.current_device())
|
||||
return device
|
||||
return device
|
||||
|
||||
@classmethod
|
||||
def empty_cache(cls) -> None:
|
||||
"""Clear the GPU device cache."""
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
|
||||
return NAME_TO_PRECISION[precision_name]
|
||||
|
||||
53
invokeai/backend/util/mask.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import torch
|
||||
|
||||
|
||||
def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor:
|
||||
"""Standardize the dimensions of a mask tensor.
|
||||
|
||||
Args:
|
||||
mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output mask tensor. The shape is (1, h, w).
|
||||
"""
|
||||
# Get the mask height and width.
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
elif mask.ndim == 3 and mask.shape[0] == 1:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).")
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
|
||||
"""Standardize the format of a mask tensor.
|
||||
|
||||
Args:
|
||||
mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w)
|
||||
or (h, w).
|
||||
|
||||
out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0
|
||||
or 1.0.
|
||||
"""
|
||||
|
||||
if not out_dtype.is_floating_point:
|
||||
raise ValueError(f"out_dtype must be a float type, but got {out_dtype}")
|
||||
|
||||
mask = to_standard_mask_dim(mask)
|
||||
mask = mask.to(out_dtype)
|
||||
|
||||
# Set masked regions to 1.0.
|
||||
if mask.dtype == torch.bool:
|
||||
mask = mask.to(out_dtype)
|
||||
else:
|
||||
mask = mask.to(out_dtype)
|
||||
mask_region = mask > 0.5
|
||||
mask[mask_region] = 1.0
|
||||
mask[~mask_region] = 0.0
|
||||
|
||||
return mask
|
||||
@@ -8,7 +8,7 @@
|
||||
<meta http-equiv="Pragma" content="no-cache">
|
||||
<meta http-equiv="Expires" content="0">
|
||||
<title>Invoke - Community Edition</title>
|
||||
<link rel="icon" type="icon" href="assets/images/invoke-favicon.svg" />
|
||||
<link id="invoke-favicon" rel="icon" type="icon" href="assets/images/invoke-favicon.svg" />
|
||||
<style>
|
||||
html,
|
||||
body {
|
||||
@@ -23,4 +23,4 @@
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</html>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { KnipConfig } from 'knip';
|
||||
|
||||
const config: KnipConfig = {
|
||||
project: ['src/**/*.{ts,tsx}!'],
|
||||
ignore: [
|
||||
// This file is only used during debugging
|
||||
'src/app/store/middleware/debugLoggerMiddleware.ts',
|
||||
@@ -10,6 +11,9 @@ const config: KnipConfig = {
|
||||
'src/features/nodes/types/v2/**',
|
||||
],
|
||||
ignoreBinaries: ['only-allow'],
|
||||
paths: {
|
||||
'public/*': ['public/*'],
|
||||
},
|
||||
};
|
||||
|
||||
export default config;
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
"build": "pnpm run lint && vite build",
|
||||
"typegen": "node scripts/typegen.js",
|
||||
"preview": "vite preview",
|
||||
"lint:knip": "knip --tags=-@knipignore",
|
||||
"lint:knip": "knip",
|
||||
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
"lint:prettier": "prettier --check .",
|
||||
@@ -52,6 +52,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@chakra-ui/react-use-size": "^2.1.0",
|
||||
"@dagrejs/dagre": "^1.1.1",
|
||||
"@dagrejs/graphlib": "^2.2.1",
|
||||
"@dnd-kit/core": "^6.1.0",
|
||||
"@dnd-kit/sortable": "^8.0.0",
|
||||
|
||||
9
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -11,6 +11,9 @@ dependencies:
|
||||
'@chakra-ui/react-use-size':
|
||||
specifier: ^2.1.0
|
||||
version: 2.1.0(react@18.2.0)
|
||||
'@dagrejs/dagre':
|
||||
specifier: ^1.1.1
|
||||
version: 1.1.1
|
||||
'@dagrejs/graphlib':
|
||||
specifier: ^2.2.1
|
||||
version: 2.2.1
|
||||
@@ -3092,6 +3095,12 @@ packages:
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/@dagrejs/dagre@1.1.1:
|
||||
resolution: {integrity: sha512-AQfT6pffEuPE32weFzhS/u3UpX+bRXUARIXL7UqLaxz497cN8pjuBlX6axO4IIECE2gBV8eLFQkGCtKX5sDaUA==}
|
||||
dependencies:
|
||||
'@dagrejs/graphlib': 2.2.1
|
||||
dev: false
|
||||
|
||||
/@dagrejs/graphlib@2.2.1:
|
||||
resolution: {integrity: sha512-xJsN1v6OAxXk6jmNdM+OS/bBE8nDCwM0yDNprXR18ZNatL6to9ggod9+l2XtiLhXfLm0NkE7+Er/cpdlM+SkUA==}
|
||||
engines: {node: '>17.0.0'}
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="16" height="16" rx="2" fill="#E6FD13"/>
|
||||
<path d="M9.61889 5.45H12.5V3.5H3.5V5.45H6.38111L9.61889 10.55H12.5V12.5H3.5V10.55H6.38111" stroke="black"/>
|
||||
<circle cx="12" cy="4" r="3" fill="#f5480c" stroke="#0d1117" stroke-width="1"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 345 B |
@@ -85,7 +85,8 @@
|
||||
"loadMore": "Mehr laden",
|
||||
"noImagesInGallery": "Keine Bilder in der Galerie",
|
||||
"loading": "Lade",
|
||||
"deleteImage": "Lösche Bild",
|
||||
"deleteImage_one": "Lösche Bild",
|
||||
"deleteImage_other": "",
|
||||
"copy": "Kopieren",
|
||||
"download": "Runterladen",
|
||||
"setCurrentImage": "Setze aktuelle Bild",
|
||||
|
||||
@@ -69,6 +69,7 @@
|
||||
"auto": "Auto",
|
||||
"back": "Back",
|
||||
"batch": "Batch Manager",
|
||||
"beta": "Beta",
|
||||
"cancel": "Cancel",
|
||||
"copy": "Copy",
|
||||
"copyError": "$t(gallery.copy) Error",
|
||||
@@ -213,6 +214,10 @@
|
||||
"resize": "Resize",
|
||||
"resizeSimple": "Resize (Simple)",
|
||||
"resizeMode": "Resize Mode",
|
||||
"ipAdapterMethod": "Method",
|
||||
"full": "Full",
|
||||
"style": "Style Only",
|
||||
"composition": "Composition Only",
|
||||
"safe": "Safe",
|
||||
"saveControlImage": "Save Control Image",
|
||||
"scribble": "scribble",
|
||||
@@ -326,7 +331,8 @@
|
||||
"drop": "Drop",
|
||||
"dropOrUpload": "$t(gallery.drop) or Upload",
|
||||
"dropToUpload": "$t(gallery.drop) to Upload",
|
||||
"deleteImage": "Delete Image",
|
||||
"deleteImage_one": "Delete Image",
|
||||
"deleteImage_other": "Delete {{count}} Images",
|
||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
||||
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||
"download": "Download",
|
||||
@@ -769,6 +775,8 @@
|
||||
"float": "Float",
|
||||
"fullyContainNodes": "Fully Contain Nodes to Select",
|
||||
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
|
||||
"showEdgeLabels": "Show Edge Labels",
|
||||
"showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
|
||||
"hideLegendNodes": "Hide Field Type Legend",
|
||||
"hideMinimapnodes": "Hide MiniMap",
|
||||
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
|
||||
@@ -849,6 +857,7 @@
|
||||
"version": "Version",
|
||||
"versionUnknown": " Version Unknown",
|
||||
"workflow": "Workflow",
|
||||
"graph": "Graph",
|
||||
"workflowAuthor": "Author",
|
||||
"workflowContact": "Contact",
|
||||
"workflowDescription": "Short Description",
|
||||
@@ -1423,6 +1432,7 @@
|
||||
"eraseBoundingBox": "Erase Bounding Box",
|
||||
"eraser": "Eraser",
|
||||
"fillBoundingBox": "Fill Bounding Box",
|
||||
"hideBoundingBox": "Hide Bounding Box",
|
||||
"initialFitImageSize": "Fit Image Size on Drop",
|
||||
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
|
||||
"layer": "Layer",
|
||||
@@ -1440,6 +1450,7 @@
|
||||
"saveMask": "Save $t(unifiedCanvas.mask)",
|
||||
"saveToGallery": "Save To Gallery",
|
||||
"scaledBoundingBox": "Scaled Bounding Box",
|
||||
"showBoundingBox": "Show Bounding Box",
|
||||
"showCanvasDebugInfo": "Show Additional Canvas Info",
|
||||
"showGrid": "Show Grid",
|
||||
"showResultsOn": "Show Results (On)",
|
||||
@@ -1482,7 +1493,11 @@
|
||||
"workflowName": "Workflow Name",
|
||||
"newWorkflowCreated": "New Workflow Created",
|
||||
"workflowCleared": "Workflow Cleared",
|
||||
"workflowEditorMenu": "Workflow Editor Menu"
|
||||
"workflowEditorMenu": "Workflow Editor Menu",
|
||||
"loadFromGraph": "Load Workflow from Graph",
|
||||
"convertGraph": "Convert Graph",
|
||||
"loadWorkflow": "$t(common.load) Workflow",
|
||||
"autoLayout": "Auto Layout"
|
||||
},
|
||||
"app": {
|
||||
"storeNotInitialized": "Store is not initialized"
|
||||
|
||||
@@ -33,7 +33,9 @@
|
||||
"autoSwitchNewImages": "Auto seleccionar Imágenes nuevas",
|
||||
"loadMore": "Cargar más",
|
||||
"noImagesInGallery": "No hay imágenes para mostrar",
|
||||
"deleteImage": "Eliminar Imagen",
|
||||
"deleteImage_one": "Eliminar Imagen",
|
||||
"deleteImage_many": "",
|
||||
"deleteImage_other": "",
|
||||
"deleteImageBin": "Las imágenes eliminadas se enviarán a la papelera de tu sistema operativo.",
|
||||
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
|
||||
"assets": "Activos",
|
||||
|
||||
@@ -82,7 +82,9 @@
|
||||
"autoSwitchNewImages": "Passaggio automatico a nuove immagini",
|
||||
"loadMore": "Carica altro",
|
||||
"noImagesInGallery": "Nessuna immagine da visualizzare",
|
||||
"deleteImage": "Elimina l'immagine",
|
||||
"deleteImage_one": "Elimina l'immagine",
|
||||
"deleteImage_many": "Elimina {{count}} immagini",
|
||||
"deleteImage_other": "Elimina {{count}} immagini",
|
||||
"deleteImagePermanent": "Le immagini eliminate non possono essere ripristinate.",
|
||||
"deleteImageBin": "Le immagini eliminate verranno spostate nel cestino del tuo sistema operativo.",
|
||||
"assets": "Risorse",
|
||||
@@ -444,7 +446,8 @@
|
||||
"hfTokenInvalidErrorMessage2": "Aggiornalo in ",
|
||||
"main": "Principali",
|
||||
"noModelsInstalledDesc1": "Installa i modelli con",
|
||||
"ipAdapters": "Adattatori IP"
|
||||
"ipAdapters": "Adattatori IP",
|
||||
"noMatchingModels": "Nessun modello corrispondente"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@@ -526,7 +529,12 @@
|
||||
"aspect": "Aspetto",
|
||||
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
|
||||
"remixImage": "Remixa l'immagine",
|
||||
"coherenceEdgeSize": "Dim. bordo"
|
||||
"coherenceEdgeSize": "Dim. bordo",
|
||||
"infillMosaicTileWidth": "Larghezza piastrella",
|
||||
"infillMosaicMinColor": "Colore minimo",
|
||||
"infillMosaicMaxColor": "Colore massimo",
|
||||
"infillMosaicTileHeight": "Altezza piastrella",
|
||||
"infillColorValue": "Colore di riempimento"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Modelli",
|
||||
@@ -620,7 +628,8 @@
|
||||
"uploadInitialImage": "Carica l'immagine iniziale",
|
||||
"problemDownloadingImage": "Impossibile scaricare l'immagine",
|
||||
"prunedQueue": "Coda ripulita",
|
||||
"modelImportCanceled": "Importazione del modello annullata"
|
||||
"modelImportCanceled": "Importazione del modello annullata",
|
||||
"parameters": "Parametri"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@@ -689,7 +698,10 @@
|
||||
"coherenceModeBoxBlur": "Sfocatura Box",
|
||||
"coherenceModeStaged": "Maschera espansa",
|
||||
"invertBrushSizeScrollDirection": "Inverti scorrimento per dimensione pennello",
|
||||
"discardCurrent": "Scarta l'attuale"
|
||||
"discardCurrent": "Scarta l'attuale",
|
||||
"initialFitImageSize": "Adatta dimensione immagine al rilascio",
|
||||
"hideBoundingBox": "Nascondi il rettangolo di selezione",
|
||||
"showBoundingBox": "Mostra il rettangolo di selezione"
|
||||
},
|
||||
"accessibility": {
|
||||
"invokeProgressBar": "Barra di avanzamento generazione",
|
||||
@@ -832,7 +844,8 @@
|
||||
"editMode": "Modifica nell'editor del flusso di lavoro",
|
||||
"resetToDefaultValue": "Ripristina il valore predefinito",
|
||||
"noFieldsViewMode": "Questo flusso di lavoro non ha campi selezionati da visualizzare. Visualizza il flusso di lavoro completo per configurare i valori.",
|
||||
"edit": "Modifica"
|
||||
"edit": "Modifica",
|
||||
"graph": "Grafico"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Aggiungi automaticamente bacheca",
|
||||
@@ -1346,13 +1359,13 @@
|
||||
]
|
||||
},
|
||||
"seamlessTilingXAxis": {
|
||||
"heading": "Asse X di piastrellatura senza cuciture",
|
||||
"heading": "Piastrella senza giunte sull'asse X",
|
||||
"paragraphs": [
|
||||
"Affianca senza soluzione di continuità un'immagine lungo l'asse orizzontale."
|
||||
]
|
||||
},
|
||||
"seamlessTilingYAxis": {
|
||||
"heading": "Asse Y di piastrellatura senza cuciture",
|
||||
"heading": "Piastrella senza giunte sull'asse Y",
|
||||
"paragraphs": [
|
||||
"Affianca senza soluzione di continuità un'immagine lungo l'asse verticale."
|
||||
]
|
||||
@@ -1476,7 +1489,11 @@
|
||||
"name": "Nome",
|
||||
"updated": "Aggiornato",
|
||||
"projectWorkflows": "Flussi di lavoro del progetto",
|
||||
"opened": "Aperto"
|
||||
"opened": "Aperto",
|
||||
"convertGraph": "Converti grafico",
|
||||
"loadWorkflow": "$t(common.load) Flusso di lavoro",
|
||||
"autoLayout": "Disposizione automatica",
|
||||
"loadFromGraph": "Carica il flusso di lavoro dal grafico"
|
||||
},
|
||||
"app": {
|
||||
"storeNotInitialized": "Il negozio non è inizializzato"
|
||||
|
||||
@@ -90,7 +90,7 @@
|
||||
"problemDeletingImages": "画像の削除中に問題が発生",
|
||||
"drop": "ドロップ",
|
||||
"dropOrUpload": "$t(gallery.drop) またはアップロード",
|
||||
"deleteImage": "画像を削除",
|
||||
"deleteImage_other": "画像を削除",
|
||||
"deleteImageBin": "削除された画像はOSのゴミ箱に送られます。",
|
||||
"deleteImagePermanent": "削除された画像は復元できません。",
|
||||
"download": "ダウンロード",
|
||||
|
||||
@@ -82,7 +82,7 @@
|
||||
"drop": "드랍",
|
||||
"problemDeletingImages": "이미지 삭제 중 발생한 문제",
|
||||
"downloadSelection": "선택 항목 다운로드",
|
||||
"deleteImage": "이미지 삭제",
|
||||
"deleteImage_other": "이미지 삭제",
|
||||
"currentlyInUse": "이 이미지는 현재 다음 기능에서 사용되고 있습니다:",
|
||||
"dropOrUpload": "$t(gallery.drop) 또는 업로드",
|
||||
"copy": "복사",
|
||||
|
||||
@@ -42,7 +42,8 @@
|
||||
"autoSwitchNewImages": "Wissel autom. naar nieuwe afbeeldingen",
|
||||
"loadMore": "Laad meer",
|
||||
"noImagesInGallery": "Geen afbeeldingen om te tonen",
|
||||
"deleteImage": "Verwijder afbeelding",
|
||||
"deleteImage_one": "Verwijder afbeelding",
|
||||
"deleteImage_other": "",
|
||||
"deleteImageBin": "Verwijderde afbeeldingen worden naar de prullenbak van je besturingssysteem gestuurd.",
|
||||
"deleteImagePermanent": "Verwijderde afbeeldingen kunnen niet worden hersteld.",
|
||||
"assets": "Eigen onderdelen",
|
||||
|
||||
@@ -86,7 +86,9 @@
|
||||
"noImagesInGallery": "Изображений нет",
|
||||
"deleteImagePermanent": "Удаленные изображения невозможно восстановить.",
|
||||
"deleteImageBin": "Удаленные изображения будут отправлены в корзину вашей операционной системы.",
|
||||
"deleteImage": "Удалить изображение",
|
||||
"deleteImage_one": "Удалить изображение",
|
||||
"deleteImage_few": "",
|
||||
"deleteImage_many": "",
|
||||
"assets": "Ресурсы",
|
||||
"autoAssignBoardOnClick": "Авто-назначение доски по клику",
|
||||
"deleteSelection": "Удалить выделенное",
|
||||
@@ -448,7 +450,9 @@
|
||||
"loraModels": "LoRAs",
|
||||
"main": "Основные",
|
||||
"noModelsInstalled": "Нет установленных моделей",
|
||||
"noModelsInstalledDesc1": "Установите модели с помощью"
|
||||
"noModelsInstalledDesc1": "Установите модели с помощью",
|
||||
"noMatchingModels": "Нет подходящих моделей",
|
||||
"ipAdapters": "IP адаптеры"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Изображения",
|
||||
@@ -532,7 +536,12 @@
|
||||
"lockAspectRatio": "Заблокировать соотношение",
|
||||
"remixImage": "Ремикс изображения",
|
||||
"coherenceMinDenoise": "Мин. шумоподавление",
|
||||
"coherenceEdgeSize": "Размер края"
|
||||
"coherenceEdgeSize": "Размер края",
|
||||
"infillMosaicTileWidth": "Ширина плиток",
|
||||
"infillMosaicTileHeight": "Высота плиток",
|
||||
"infillMosaicMinColor": "Мин цвет",
|
||||
"infillMosaicMaxColor": "Макс цвет",
|
||||
"infillColorValue": "Цвет заливки"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Модели",
|
||||
@@ -626,7 +635,8 @@
|
||||
"uploadInitialImage": "Загрузить начальное изображение",
|
||||
"resetInitialImage": "Сбросить начальное изображение",
|
||||
"prunedQueue": "Урезанная очередь",
|
||||
"modelImportCanceled": "Импорт модели отменен"
|
||||
"modelImportCanceled": "Импорт модели отменен",
|
||||
"parameters": "Параметры"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@@ -695,7 +705,8 @@
|
||||
"coherenceModeGaussianBlur": "Размытие по Гауссу",
|
||||
"coherenceModeBoxBlur": "коробчатое размытие",
|
||||
"discardCurrent": "Отбросить текущее",
|
||||
"invertBrushSizeScrollDirection": "Инвертировать прокрутку для размера кисти"
|
||||
"invertBrushSizeScrollDirection": "Инвертировать прокрутку для размера кисти",
|
||||
"initialFitImageSize": "Подогнать размер изображения при перебросе"
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Загрузить изображение",
|
||||
@@ -921,7 +932,8 @@
|
||||
"modelSize": "Размер модели",
|
||||
"small": "Маленький",
|
||||
"body": "Тело",
|
||||
"hands": "Руки"
|
||||
"hands": "Руки",
|
||||
"selectCLIPVisionModel": "Выбрать модель CLIP Vision"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Авто добавление Доски",
|
||||
|
||||
@@ -298,7 +298,8 @@
|
||||
"noImagesInGallery": "Gösterilecek Görsel Yok",
|
||||
"autoSwitchNewImages": "Yeni Görseli Biter Bitmez Gör",
|
||||
"currentlyInUse": "Bu görsel şurada kullanımda:",
|
||||
"deleteImage": "Görseli Sil",
|
||||
"deleteImage_one": "Görseli Sil",
|
||||
"deleteImage_other": "",
|
||||
"loadMore": "Daha Getir",
|
||||
"setCurrentImage": "Çalışma Görseli Yap",
|
||||
"unableToLoad": "Galeri Yüklenemedi",
|
||||
|
||||
@@ -65,7 +65,12 @@
|
||||
"nextPage": "下一页",
|
||||
"saveAs": "保存为",
|
||||
"ai": "ai",
|
||||
"or": "或"
|
||||
"or": "或",
|
||||
"aboutDesc": "使用 Invoke 工作?查看:",
|
||||
"add": "添加",
|
||||
"loglevel": "日志级别",
|
||||
"copy": "复制",
|
||||
"localSystem": "本地系统"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "预览大小",
|
||||
@@ -73,7 +78,7 @@
|
||||
"autoSwitchNewImages": "自动切换到新图像",
|
||||
"loadMore": "加载更多",
|
||||
"noImagesInGallery": "无图像可用于显示",
|
||||
"deleteImage": "删除图片",
|
||||
"deleteImage_other": "删除图片",
|
||||
"deleteImageBin": "被删除的图片会发送到你操作系统的回收站。",
|
||||
"deleteImagePermanent": "删除的图片无法被恢复。",
|
||||
"assets": "素材",
|
||||
@@ -599,7 +604,8 @@
|
||||
"loadMore": "加载更多",
|
||||
"mode": "模式",
|
||||
"resetUI": "$t(accessibility.reset) UI",
|
||||
"createIssue": "创建问题"
|
||||
"createIssue": "创建问题",
|
||||
"about": "关于"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@@ -1201,7 +1207,16 @@
|
||||
"workflows": "工作流",
|
||||
"noDescription": "无描述",
|
||||
"uploadWorkflow": "从文件中加载",
|
||||
"newWorkflowCreated": "已创建新的工作流"
|
||||
"newWorkflowCreated": "已创建新的工作流",
|
||||
"name": "名称",
|
||||
"defaultWorkflows": "默认工作流",
|
||||
"created": "已创建",
|
||||
"ascending": "升序",
|
||||
"descending": "降序",
|
||||
"updated": "已更新",
|
||||
"userWorkflows": "我的工作流",
|
||||
"projectWorkflows": "项目工作流",
|
||||
"opened": "已打开"
|
||||
},
|
||||
"app": {
|
||||
"storeNotInitialized": "商店尚未初始化"
|
||||
@@ -1219,7 +1234,8 @@
|
||||
"title": "生成"
|
||||
},
|
||||
"advanced": {
|
||||
"title": "高级"
|
||||
"title": "高级",
|
||||
"options": "$t(accordions.advanced.title) 选项"
|
||||
},
|
||||
"image": {
|
||||
"title": "图像"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Box, useGlobalModifiersInit } from '@invoke-ai/ui-library';
|
||||
import { useSocketIO } from 'app/hooks/useSocketIO';
|
||||
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -70,6 +71,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
|
||||
}, [dispatch]);
|
||||
|
||||
useStarterModelsToast();
|
||||
useSyncQueueStatus();
|
||||
|
||||
return (
|
||||
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
|
||||
25
invokeai/frontend/web/src/app/hooks/useSyncQueueStatus.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
const baseTitle = document.title;
|
||||
const invokeLogoSVG = 'assets/images/invoke-favicon.svg';
|
||||
const invokeAlertLogoSVG = 'assets/images/invoke-alert-favicon.svg';
|
||||
|
||||
/**
|
||||
* This hook synchronizes the queue status with the page's title and favicon.
|
||||
* It should be considered a singleton and only used once in the component tree.
|
||||
*/
|
||||
export const useSyncQueueStatus = () => {
|
||||
const { queueSize } = useGetQueueStatusQuery(undefined, {
|
||||
selectFromResult: (res) => ({
|
||||
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
|
||||
}),
|
||||
});
|
||||
useEffect(() => {
|
||||
document.title = queueSize > 0 ? `(${queueSize}) ${baseTitle}` : baseTitle;
|
||||
const faviconEl = document.getElementById('invoke-favicon');
|
||||
if (faviconEl instanceof HTMLLinkElement) {
|
||||
faviconEl.href = queueSize > 0 ? invokeAlertLogoSVG : invokeLogoSVG;
|
||||
}
|
||||
}, [queueSize]);
|
||||
};
|
||||
@@ -1,12 +1,18 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { canvasBatchIdsReset, commitStagingAreaImage, discardStagedImages } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
canvasBatchIdsReset,
|
||||
commitStagingAreaImage,
|
||||
discardStagedImages,
|
||||
resetCanvas,
|
||||
setInitialCanvasImage,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages);
|
||||
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage);
|
||||
|
||||
export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Flex, Image, Spinner } from '@invoke-ai/ui-library';
|
||||
/** @knipignore */
|
||||
import InvokeLogoWhite from 'public/assets/images/invoke-symbol-wht-lrg.svg';
|
||||
import { memo } from 'react';
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
export const useGlobalHotkeys = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isModelManagerEnabled = useFeatureStatus('modelManager').isFeatureEnabled;
|
||||
const isModelManagerEnabled = useFeatureStatus('modelManager');
|
||||
const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack();
|
||||
|
||||
useHotkeys(
|
||||
|
||||
@@ -49,14 +49,20 @@ const selector = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
const ClearStagingIntermediatesIconButton = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const totalStagedImages = useAppSelector((s) => s.canvas.layerState.stagingArea.images.length);
|
||||
|
||||
const handleDiscardStagingArea = useCallback(() => {
|
||||
dispatch(discardStagedImages());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleDiscardStagingImage = useCallback(() => {
|
||||
dispatch(discardStagedImage());
|
||||
}, [dispatch]);
|
||||
// Discarding all staged images triggers cancelation of all canvas batches. It's too easy to accidentally
|
||||
// click the discard button, so to prevent accidental cancelation of all batches, we only discard the current
|
||||
// image if there are more than one staged images.
|
||||
if (totalStagedImages > 1) {
|
||||
dispatch(discardStagedImage());
|
||||
}
|
||||
}, [dispatch, totalStagedImages]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -67,6 +73,7 @@ const ClearStagingIntermediatesIconButton = () => {
|
||||
onClick={handleDiscardStagingImage}
|
||||
colorScheme="invokeBlue"
|
||||
fontSize={16}
|
||||
isDisabled={totalStagedImages <= 1}
|
||||
/>
|
||||
<IconButton
|
||||
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
|
||||
|
||||
@@ -13,7 +13,13 @@ import {
|
||||
} from 'features/canvas/store/actions';
|
||||
import { $canvasBaseLayer, $tool } from 'features/canvas/store/canvasNanostore';
|
||||
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { resetCanvas, resetCanvasView, setIsMaskEnabled, setLayer } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
resetCanvas,
|
||||
resetCanvasView,
|
||||
setIsMaskEnabled,
|
||||
setLayer,
|
||||
setShouldShowBoundingBox,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import type { CanvasLayer } from 'features/canvas/store/canvasTypes';
|
||||
import { LAYER_NAMES_DICT } from 'features/canvas/store/canvasTypes';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -23,6 +29,8 @@ import {
|
||||
PiCopyBold,
|
||||
PiCrosshairSimpleBold,
|
||||
PiDownloadSimpleBold,
|
||||
PiEyeBold,
|
||||
PiEyeSlashBold,
|
||||
PiFloppyDiskBold,
|
||||
PiHandGrabbingBold,
|
||||
PiStackBold,
|
||||
@@ -44,6 +52,7 @@ const IAICanvasToolbar = () => {
|
||||
const isStaging = useAppSelector(isStagingSelector);
|
||||
const { t } = useTranslation();
|
||||
const { isClipboardAPIAvailable } = useCopyImageToClipboard();
|
||||
const shouldShowBoundingBox = useAppSelector((s) => s.canvas.shouldShowBoundingBox);
|
||||
|
||||
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
|
||||
postUploadAction: { type: 'SET_CANVAS_INITIAL_IMAGE' },
|
||||
@@ -61,6 +70,18 @@ const IAICanvasToolbar = () => {
|
||||
[]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'shift+h',
|
||||
() => {
|
||||
dispatch(setShouldShowBoundingBox(!shouldShowBoundingBox));
|
||||
},
|
||||
{
|
||||
enabled: () => !isStaging,
|
||||
preventDefault: true,
|
||||
},
|
||||
[shouldShowBoundingBox]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
['r'],
|
||||
() => {
|
||||
@@ -125,6 +146,10 @@ const IAICanvasToolbar = () => {
|
||||
$tool.set('move');
|
||||
}, []);
|
||||
|
||||
const handleSetShouldShowBoundingBox = useCallback(() => {
|
||||
dispatch(setShouldShowBoundingBox(!shouldShowBoundingBox));
|
||||
}, [dispatch, shouldShowBoundingBox]);
|
||||
|
||||
const handleResetCanvasView = useCallback(
|
||||
(shouldScaleTo1 = false) => {
|
||||
const canvasBaseLayer = $canvasBaseLayer.get();
|
||||
@@ -212,6 +237,13 @@ const IAICanvasToolbar = () => {
|
||||
isChecked={tool === 'move' || isStaging}
|
||||
onClick={handleSelectMoveTool}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${shouldShowBoundingBox ? t('unifiedCanvas.hideBoundingBox') : t('unifiedCanvas.showBoundingBox')} (Shift + H)`}
|
||||
tooltip={`${shouldShowBoundingBox ? t('unifiedCanvas.hideBoundingBox') : t('unifiedCanvas.showBoundingBox')} (Shift + H)`}
|
||||
icon={shouldShowBoundingBox ? <PiEyeBold /> : <PiEyeSlashBold />}
|
||||
onClick={handleSetShouldShowBoundingBox}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.resetView')} (R)`}
|
||||
tooltip={`${t('unifiedCanvas.resetView')} (R)`}
|
||||
|
||||
@@ -7,12 +7,7 @@ import {
|
||||
resetToolInteractionState,
|
||||
} from 'features/canvas/store/canvasNanostore';
|
||||
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import {
|
||||
clearMask,
|
||||
setIsMaskEnabled,
|
||||
setShouldShowBoundingBox,
|
||||
setShouldSnapToGrid,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { clearMask, setIsMaskEnabled, setShouldSnapToGrid } from 'features/canvas/store/canvasSlice';
|
||||
import { isInteractiveTarget } from 'features/canvas/util/isInteractiveTarget';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
@@ -21,7 +16,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
const useInpaintingCanvasHotkeys = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const shouldShowBoundingBox = useAppSelector((s) => s.canvas.shouldShowBoundingBox);
|
||||
const isStaging = useAppSelector(isStagingSelector);
|
||||
const isMaskEnabled = useAppSelector((s) => s.canvas.isMaskEnabled);
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.canvas.shouldSnapToGrid);
|
||||
@@ -79,18 +73,6 @@ const useInpaintingCanvasHotkeys = () => {
|
||||
}
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'shift+h',
|
||||
() => {
|
||||
dispatch(setShouldShowBoundingBox(!shouldShowBoundingBox));
|
||||
},
|
||||
{
|
||||
enabled: () => !isStaging,
|
||||
preventDefault: true,
|
||||
},
|
||||
[activeTabName, shouldShowBoundingBox]
|
||||
);
|
||||
|
||||
const onKeyDown = useCallback(
|
||||
(e: KeyboardEvent) => {
|
||||
if (e.repeat || e.key !== ' ' || isInteractiveTarget(e.target) || activeTabName !== 'unifiedCanvas') {
|
||||
|
||||
@@ -190,7 +190,6 @@ export const canvasSlice = createSlice({
|
||||
],
|
||||
};
|
||||
state.futureLayerStates = [];
|
||||
state.batchIds = [];
|
||||
|
||||
const newScale = calculateScale(
|
||||
stageDimensions.width,
|
||||
@@ -286,40 +285,14 @@ export const canvasSlice = createSlice({
|
||||
},
|
||||
discardStagedImages: (state) => {
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
|
||||
|
||||
resetStagingArea(state);
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
state.shouldShowStagingImage = true;
|
||||
state.batchIds = [];
|
||||
},
|
||||
discardStagedImage: (state) => {
|
||||
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
images.splice(selectedImageIndex, 1);
|
||||
|
||||
if (images.length === 0) {
|
||||
pushToPrevLayerStates(state);
|
||||
|
||||
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
|
||||
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
state.shouldShowStagingImage = true;
|
||||
state.batchIds = [];
|
||||
}
|
||||
|
||||
if (selectedImageIndex >= images.length) {
|
||||
state.layerState.stagingArea.selectedImageIndex = images.length - 1;
|
||||
}
|
||||
|
||||
if (!images.length) {
|
||||
state.shouldShowStagingImage = false;
|
||||
state.shouldShowStagingOutline = false;
|
||||
}
|
||||
|
||||
state.layerState.stagingArea.selectedImageIndex = Math.max(0, images.length - 1);
|
||||
state.futureLayerStates = [];
|
||||
},
|
||||
addFillRect: (state) => {
|
||||
@@ -433,7 +406,6 @@ export const canvasSlice = createSlice({
|
||||
pushToPrevLayerStates(state);
|
||||
state.layerState = deepClone(initialLayerState);
|
||||
state.futureLayerStates = [];
|
||||
state.batchIds = [];
|
||||
state.boundingBoxCoordinates = {
|
||||
...initialCanvasState.boundingBoxCoordinates,
|
||||
};
|
||||
@@ -534,12 +506,9 @@ export const canvasSlice = createSlice({
|
||||
...imageToCommit,
|
||||
});
|
||||
}
|
||||
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
|
||||
|
||||
resetStagingArea(state);
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
state.shouldShowStagingImage = true;
|
||||
state.batchIds = [];
|
||||
},
|
||||
setBoundingBoxScaleMethod: {
|
||||
reducer: (state, action: PayloadActionWithOptimalDimension<BoundingBoxScaleMethod>) => {
|
||||
@@ -647,12 +616,19 @@ export const canvasSlice = createSlice({
|
||||
if (batch_status.in_progress === 0 && batch_status.pending === 0) {
|
||||
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
|
||||
}
|
||||
|
||||
const queueItemStatus = action.payload.data.queue_item.status;
|
||||
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
|
||||
resetStagingAreaIfEmpty(state);
|
||||
}
|
||||
});
|
||||
builder.addMatcher(queueApi.endpoints.clearQueue.matchFulfilled, (state) => {
|
||||
state.batchIds = [];
|
||||
resetStagingAreaIfEmpty(state);
|
||||
});
|
||||
builder.addMatcher(queueApi.endpoints.cancelByBatchIds.matchFulfilled, (state, action) => {
|
||||
state.batchIds = state.batchIds.filter((id) => !action.meta.arg.originalArgs.batch_ids.includes(id));
|
||||
resetStagingAreaIfEmpty(state);
|
||||
});
|
||||
},
|
||||
});
|
||||
@@ -726,7 +702,7 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
|
||||
name: canvasSlice.name,
|
||||
initialState: initialCanvasState,
|
||||
migrate: migrateCanvasState,
|
||||
persistDenylist: [],
|
||||
persistDenylist: ['shouldShowStagingImage', 'shouldShowStagingOutline'],
|
||||
};
|
||||
|
||||
const pushToPrevLayerStates = (state: CanvasState) => {
|
||||
@@ -742,3 +718,15 @@ const pushToFutureLayerStates = (state: CanvasState) => {
|
||||
state.futureLayerStates = state.futureLayerStates.slice(0, MAX_HISTORY);
|
||||
}
|
||||
};
|
||||
|
||||
const resetStagingAreaIfEmpty = (state: CanvasState) => {
|
||||
if (state.batchIds.length === 0 && state.layerState.stagingArea.images.length === 0) {
|
||||
resetStagingArea(state);
|
||||
}
|
||||
};
|
||||
|
||||
const resetStagingArea = (state: CanvasState) => {
|
||||
state.layerState.stagingArea = { ...initialCanvasState.layerState.stagingArea };
|
||||
state.shouldShowStagingImage = initialCanvasState.shouldShowStagingImage;
|
||||
state.shouldShowStagingOutline = initialCanvasState.shouldShowStagingOutline;
|
||||
};
|
||||
|
||||
@@ -21,6 +21,7 @@ import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
|
||||
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
|
||||
import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
|
||||
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
|
||||
import ParamControlAdapterIPMethod from './parameters/ParamControlAdapterIPMethod';
|
||||
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
|
||||
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
|
||||
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
|
||||
@@ -111,7 +112,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
|
||||
|
||||
<Flex w="full" flexDir="column" gap={4}>
|
||||
<Flex gap={8} w="full" alignItems="center">
|
||||
<Flex flexDir="column" gap={2} h={32} w="full">
|
||||
<Flex flexDir="column" gap={4} h={controlAdapterType === 'ip_adapter' ? 40 : 32} w="full">
|
||||
<ParamControlAdapterIPMethod id={id} />
|
||||
<ParamControlAdapterWeight id={id} />
|
||||
<ParamControlAdapterBeginEnd id={id} />
|
||||
</Flex>
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useControlAdapterIPMethod } from 'features/controlAdapters/hooks/useControlAdapterIPMethod';
|
||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||
import { controlAdapterIPMethodChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { IPMethod } from 'features/controlAdapters/store/types';
|
||||
import { isIPMethod } from 'features/controlAdapters/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
const ParamControlAdapterIPMethod = ({ id }: Props) => {
|
||||
const isEnabled = useControlAdapterIsEnabled(id);
|
||||
const method = useControlAdapterIPMethod(id);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const options: { label: string; value: IPMethod }[] = useMemo(
|
||||
() => [
|
||||
{ label: t('controlnet.full'), value: 'full' },
|
||||
{ label: `${t('controlnet.style')} (${t('common.beta')})`, value: 'style' },
|
||||
{ label: `${t('controlnet.composition')} (${t('common.beta')})`, value: 'composition' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
|
||||
const handleIPMethodChanged = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isIPMethod(v?.value)) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
controlAdapterIPMethodChanged({
|
||||
id,
|
||||
method: v.value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[id, dispatch]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === method), [options, method]);
|
||||
|
||||
if (!method) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<InformationalPopover feature="controlNetResizeMode">
|
||||
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox value={value} options={options} isDisabled={!isEnabled} onChange={handleIPMethodChanged} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamControlAdapterIPMethod);
|
||||
@@ -103,7 +103,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
|
||||
return (
|
||||
<Flex sx={{ gap: 2 }}>
|
||||
<Tooltip label={value?.description}>
|
||||
<Tooltip label={selectedModel?.description}>
|
||||
<FormControl
|
||||
isDisabled={!isEnabled}
|
||||
isInvalid={!value || mainModel?.base !== modelConfig?.base}
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectControlAdapterById,
|
||||
selectControlAdaptersSlice,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useControlAdapterIPMethod = (id: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
|
||||
const cn = selectControlAdapterById(controlAdapters, id);
|
||||
if (cn && cn?.type === 'ip_adapter') {
|
||||
return cn.method;
|
||||
}
|
||||
}),
|
||||
[id]
|
||||
);
|
||||
|
||||
const method = useAppSelector(selector);
|
||||
|
||||
return method;
|
||||
};
|
||||
@@ -21,6 +21,7 @@ import type {
|
||||
ControlAdapterType,
|
||||
ControlMode,
|
||||
ControlNetConfig,
|
||||
IPMethod,
|
||||
RequiredControlAdapterProcessorNode,
|
||||
ResizeMode,
|
||||
T2IAdapterConfig,
|
||||
@@ -245,6 +246,10 @@ export const controlAdaptersSlice = createSlice({
|
||||
}
|
||||
caAdapter.updateOne(state, { id, changes: { controlMode } });
|
||||
},
|
||||
controlAdapterIPMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethod }>) => {
|
||||
const { id, method } = action.payload;
|
||||
caAdapter.updateOne(state, { id, changes: { method } });
|
||||
},
|
||||
controlAdapterCLIPVisionModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
|
||||
@@ -390,6 +395,7 @@ export const {
|
||||
controlAdapterIsEnabledChanged,
|
||||
controlAdapterModelChanged,
|
||||
controlAdapterCLIPVisionModelChanged,
|
||||
controlAdapterIPMethodChanged,
|
||||
controlAdapterWeightChanged,
|
||||
controlAdapterBeginStepPctChanged,
|
||||
controlAdapterEndStepPctChanged,
|
||||
|
||||
@@ -210,6 +210,10 @@ const zResizeMode = z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_r
|
||||
export type ResizeMode = z.infer<typeof zResizeMode>;
|
||||
export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;
|
||||
|
||||
const zIPMethod = z.enum(['full', 'style', 'composition']);
|
||||
export type IPMethod = z.infer<typeof zIPMethod>;
|
||||
export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;
|
||||
|
||||
export type ControlNetConfig = {
|
||||
type: 'controlnet';
|
||||
id: string;
|
||||
@@ -253,6 +257,7 @@ export type IPAdapterConfig = {
|
||||
model: ParameterIPAdapterModel | null;
|
||||
clipVisionModel: CLIPVisionModel;
|
||||
weight: number;
|
||||
method: IPMethod;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
};
|
||||
|
||||
@@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
|
||||
isEnabled: true,
|
||||
controlImage: null,
|
||||
model: null,
|
||||
method: 'full',
|
||||
clipVisionModel: 'ViT-H',
|
||||
weight: 1,
|
||||
beginStepPct: 0,
|
||||
|
||||
@@ -13,13 +13,15 @@ export const DeleteImageButton = memo((props: DeleteImageButtonProps) => {
|
||||
const { onClick, isDisabled } = props;
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const imageSelectionLength: number = useAppSelector((s) => s.gallery.selection.length);
|
||||
const labelMessage: string = `${t('gallery.deleteImage', { count: imageSelectionLength })} (Del)`;
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
onClick={onClick}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
tooltip={`${t('gallery.deleteImage')} (Del)`}
|
||||
aria-label={`${t('gallery.deleteImage')} (Del)`}
|
||||
tooltip={labelMessage}
|
||||
aria-label={labelMessage}
|
||||
isDisabled={isDisabled || !isConnected}
|
||||
colorScheme="error"
|
||||
/>
|
||||
|
||||
@@ -80,7 +80,7 @@ const DeleteImageModal = () => {
|
||||
|
||||
return (
|
||||
<ConfirmationAlertDialog
|
||||
title={t('gallery.deleteImage')}
|
||||
title={t('gallery.deleteImage', { count: imagesToDelete.length })}
|
||||
isOpen={isModalOpen}
|
||||
onClose={handleClose}
|
||||
cancelButtonText={t('boards.cancel')}
|
||||
|
||||
@@ -32,7 +32,7 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
|
||||
|
||||
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
|
||||
const boardName = useBoardName(board_id);
|
||||
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
|
||||
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
|
||||
|
||||
const [bulkDownload] = useBulkDownloadImagesMutation();
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import type { RemoveFromBoardDropData } from 'features/dnd/types';
|
||||
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
|
||||
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
|
||||
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||
/** @knipignore */
|
||||
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -51,9 +51,10 @@ const CurrentImageButtons = () => {
|
||||
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
|
||||
const shouldShowProgressInViewer = useAppSelector((s) => s.ui.shouldShowProgressInViewer);
|
||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
||||
const selection = useAppSelector((s) => s.gallery.selection);
|
||||
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
|
||||
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling');
|
||||
const isQueueMutationInProgress = useIsQueueMutationInProgress();
|
||||
const toaster = useAppToaster();
|
||||
const { t } = useTranslation();
|
||||
@@ -102,8 +103,8 @@ const CurrentImageButtons = () => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
dispatch(imagesToDeleteSelected([imageDTO]));
|
||||
}, [dispatch, imageDTO]);
|
||||
dispatch(imagesToDeleteSelected(selection));
|
||||
}, [dispatch, imageDTO, selection]);
|
||||
|
||||
useHotkeys(
|
||||
'Shift+U',
|
||||
|
||||
@@ -20,7 +20,7 @@ const MultipleSelectionMenuItems = () => {
|
||||
const selection = useAppSelector((s) => s.gallery.selection);
|
||||
const customStarUi = useStore($customStarUI);
|
||||
|
||||
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
|
||||
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
|
||||
|
||||
const [starImages] = useStarImagesMutation();
|
||||
const [unstarImages] = useUnstarImagesMutation();
|
||||
|
||||
@@ -45,7 +45,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const toaster = useAppToaster();
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas');
|
||||
const customStarUi = useStore($customStarUI);
|
||||
const { downloadImage } = useDownloadImage();
|
||||
|
||||
@@ -188,7 +188,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
)}
|
||||
<MenuDivider />
|
||||
<MenuItem color="error.300" icon={<PiTrashSimpleBold />} onClickCapture={handleDelete}>
|
||||
{t('gallery.deleteImage')}
|
||||
{t('gallery.deleteImage', { count: 1 })}
|
||||
</MenuItem>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -180,7 +180,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
<IAIDndImageIcon
|
||||
onClick={handleDelete}
|
||||
icon={<PiTrashSimpleFill size="16px" />}
|
||||
tooltip={t('gallery.deleteImage')}
|
||||
tooltip={t('gallery.deleteImage', { count: 1 })}
|
||||
styleOverrides={imageIconStyleOverrides}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -18,7 +18,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
|
||||
[imageDTO?.image_name]
|
||||
);
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
|
||||
const isMultiSelectEnabled = useFeatureStatus('multiselect');
|
||||
|
||||
const handleClick = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
|
||||
@@ -8,7 +8,7 @@ import ParamHrfStrength from './ParamHrfStrength';
|
||||
import ParamHrfToggle from './ParamHrfToggle';
|
||||
|
||||
export const HrfSettings = memo(() => {
|
||||
const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
|
||||
const isHRFFeatureEnabled = useFeatureStatus('hrf');
|
||||
const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled);
|
||||
|
||||
if (!isHRFFeatureEnabled) {
|
||||
|
||||
@@ -156,8 +156,13 @@ const parseSteps: MetadataParseFunc<ParameterSteps> = (metadata) => getProperty(
|
||||
const parseStrength: MetadataParseFunc<ParameterStrength> = (metadata) =>
|
||||
getProperty(metadata, 'strength', isParameterStrength);
|
||||
|
||||
const parseHRFEnabled: MetadataParseFunc<ParameterHRFEnabled> = (metadata) =>
|
||||
getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled);
|
||||
const parseHRFEnabled: MetadataParseFunc<ParameterHRFEnabled> = async (metadata) => {
|
||||
try {
|
||||
return await getProperty(metadata, 'hrf_enabled', isParameterHRFEnabled);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const parseHRFStrength: MetadataParseFunc<ParameterStrength> = (metadata) =>
|
||||
getProperty(metadata, 'hrf_strength', isParameterStrength);
|
||||
@@ -224,12 +229,16 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
||||
};
|
||||
|
||||
const parseAllLoRAs: MetadataParseFunc<LoRA[]> = async (metadata) => {
|
||||
const lorasRaw = await getProperty(metadata, 'loras', isArray);
|
||||
const parseResults = await Promise.allSettled(lorasRaw.map((lora) => parseLoRA(lora)));
|
||||
const loras = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<LoRA> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return loras;
|
||||
try {
|
||||
const lorasRaw = await getProperty(metadata, 'loras', isArray);
|
||||
const parseResults = await Promise.allSettled(lorasRaw.map((lora) => parseLoRA(lora)));
|
||||
const loras = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<LoRA> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return loras;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (metadataItem) => {
|
||||
@@ -288,12 +297,16 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
|
||||
};
|
||||
|
||||
const parseAllControlNets: MetadataParseFunc<ControlNetConfigMetadata[]> = async (metadata) => {
|
||||
const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray);
|
||||
const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn)));
|
||||
const controlNets = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<ControlNetConfigMetadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return controlNets;
|
||||
try {
|
||||
const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray || undefined);
|
||||
const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn)));
|
||||
const controlNets = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<ControlNetConfigMetadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return controlNets;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (metadataItem) => {
|
||||
@@ -348,12 +361,16 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
|
||||
};
|
||||
|
||||
const parseAllT2IAdapters: MetadataParseFunc<T2IAdapterConfigMetadata[]> = async (metadata) => {
|
||||
const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter)));
|
||||
const t2iAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfigMetadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return t2iAdapters;
|
||||
try {
|
||||
const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter)));
|
||||
const t2iAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfigMetadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return t2iAdapters;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metadataItem) => {
|
||||
@@ -369,6 +386,10 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'weight'));
|
||||
const method = zIPAdapterField.shape.method
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'method'));
|
||||
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
@@ -386,6 +407,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
||||
clipVisionModel: 'ViT-H',
|
||||
controlImage: image?.image_name ?? null,
|
||||
weight: weight ?? initialIPAdapter.weight,
|
||||
method: method ?? initialIPAdapter.method,
|
||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
||||
};
|
||||
@@ -394,12 +416,16 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
||||
};
|
||||
|
||||
const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfigMetadata[]> = async (metadata) => {
|
||||
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter)));
|
||||
const ipAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<IPAdapterConfigMetadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return ipAdapters;
|
||||
try {
|
||||
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter)));
|
||||
const ipAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<IPAdapterConfigMetadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return ipAdapters;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
export const parsers = {
|
||||
|
||||
@@ -177,11 +177,11 @@ const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
|
||||
};
|
||||
|
||||
const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
|
||||
const { dispatch } = getStore();
|
||||
dispatch(lorasReset());
|
||||
if (!loras.length) {
|
||||
return;
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
dispatch(lorasReset());
|
||||
loras.forEach((lora) => {
|
||||
dispatch(loraRecalled(lora));
|
||||
});
|
||||
@@ -192,11 +192,11 @@ const recallControlNet: MetadataRecallFunc<ControlNetConfigMetadata> = (controlN
|
||||
};
|
||||
|
||||
const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
|
||||
const { dispatch } = getStore();
|
||||
dispatch(controlNetsReset());
|
||||
if (!controlNets.length) {
|
||||
return;
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
dispatch(controlNetsReset());
|
||||
controlNets.forEach((controlNet) => {
|
||||
dispatch(controlAdapterRecalled(controlNet));
|
||||
});
|
||||
@@ -207,11 +207,11 @@ const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfigMetadata> = (t2iAdapt
|
||||
};
|
||||
|
||||
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
|
||||
const { dispatch } = getStore();
|
||||
dispatch(t2iAdaptersReset());
|
||||
if (!t2iAdapters.length) {
|
||||
return;
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
dispatch(t2iAdaptersReset());
|
||||
t2iAdapters.forEach((t2iAdapter) => {
|
||||
dispatch(controlAdapterRecalled(t2iAdapter));
|
||||
});
|
||||
@@ -222,11 +222,11 @@ const recallIPAdapter: MetadataRecallFunc<IPAdapterConfigMetadata> = (ipAdapter)
|
||||
};
|
||||
|
||||
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
|
||||
const { dispatch } = getStore();
|
||||
dispatch(ipAdaptersReset());
|
||||
if (!ipAdapters.length) {
|
||||
return;
|
||||
}
|
||||
const { dispatch } = getStore();
|
||||
dispatch(ipAdaptersReset());
|
||||
ipAdapters.forEach((ipAdapter) => {
|
||||
dispatch(controlAdapterRecalled(ipAdapter));
|
||||
});
|
||||
|
||||
@@ -10,7 +10,7 @@ const TOAST_ID = 'starterModels';
|
||||
|
||||
export const useStarterModelsToast = () => {
|
||||
const { t } = useTranslation();
|
||||
const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled;
|
||||
const isEnabled = useFeatureStatus('starterModels');
|
||||
const [didToast, setDidToast] = useState(false);
|
||||
const [mainModels, { data }] = useMainModels();
|
||||
const toast = useToast();
|
||||
|
||||
@@ -74,7 +74,6 @@ export const InstallModelForm = () => {
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
isDisabled={!formState.dirtyFields.location}
|
||||
isLoading={isLoading}
|
||||
type="submit"
|
||||
size="sm"
|
||||
>
|
||||
{t('modelManager.install')}
|
||||
|
||||
@@ -86,7 +86,6 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
||||
colorScheme="invokeYellow"
|
||||
isDisabled={!formState.isDirty}
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
type="submit"
|
||||
isLoading={isLoadingUpdateModel}
|
||||
>
|
||||
{t('common.save')}
|
||||
|
||||
@@ -116,7 +116,6 @@ export const MainModelDefaultSettings = () => {
|
||||
colorScheme="invokeYellow"
|
||||
isDisabled={!formState.isDirty}
|
||||
onClick={handleSubmit(onSubmit)}
|
||||
type="submit"
|
||||
isLoading={isLoadingUpdateModel}
|
||||
>
|
||||
{t('common.save')}
|
||||
|
||||
@@ -88,7 +88,6 @@ export const TriggerPhrases = () => {
|
||||
<Button
|
||||
leftIcon={<PiPlusBold />}
|
||||
size="sm"
|
||||
type="submit"
|
||||
onClick={addTriggerPhrase}
|
||||
isDisabled={!phrase || Boolean(errors.length)}
|
||||
isLoading={isLoading}
|
||||
|
||||
@@ -3,6 +3,7 @@ import 'reactflow/dist/style.css';
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
|
||||
import { LoadWorkflowFromGraphModal } from 'features/workflowLibrary/components/LoadWorkflowFromGraphModal/LoadWorkflowFromGraphModal';
|
||||
import { SaveWorkflowAsDialog } from 'features/workflowLibrary/components/SaveWorkflowAsDialog/SaveWorkflowAsDialog';
|
||||
import type { AnimationProps } from 'framer-motion';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
@@ -61,6 +62,7 @@ const NodeEditor = () => {
|
||||
<BottomLeftPanel />
|
||||
<MinimapPanel />
|
||||
<SaveWorkflowAsDialog />
|
||||
<LoadWorkflowFromGraphModal />
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { EdgeProps } from 'reactflow';
|
||||
import { BaseEdge, getBezierPath } from 'reactflow';
|
||||
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
|
||||
|
||||
import { makeEdgeSelector } from './util/makeEdgeSelector';
|
||||
|
||||
@@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({
|
||||
[source, sourceHandleId, target, targetHandleId, selected]
|
||||
);
|
||||
|
||||
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
|
||||
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
|
||||
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
|
||||
|
||||
const [edgePath] = getBezierPath({
|
||||
const [edgePath, labelX, labelY] = getBezierPath({
|
||||
sourceX,
|
||||
sourceY,
|
||||
sourcePosition,
|
||||
@@ -47,7 +49,33 @@ const InvocationDefaultEdge = ({
|
||||
[isSelected, shouldAnimate, stroke]
|
||||
);
|
||||
|
||||
return <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />;
|
||||
return (
|
||||
<>
|
||||
<BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />
|
||||
{label && shouldShowEdgeLabels && (
|
||||
<EdgeLabelRenderer>
|
||||
<Flex
|
||||
className="nodrag nopan"
|
||||
pointerEvents="all"
|
||||
position="absolute"
|
||||
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
borderWidth={1}
|
||||
borderColor={isSelected ? 'undefined' : 'transparent'}
|
||||
opacity={isSelected ? 1 : 0.5}
|
||||
py={1}
|
||||
px={3}
|
||||
shadow="md"
|
||||
>
|
||||
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
|
||||
{label}
|
||||
</Text>
|
||||
</Flex>
|
||||
</EdgeLabelRenderer>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(InvocationDefaultEdge);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||
import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
import { getFieldColor } from './getEdgeColor';
|
||||
@@ -10,6 +10,7 @@ const defaultReturnValue = {
|
||||
isSelected: false,
|
||||
shouldAnimate: false,
|
||||
stroke: colorTokenToCssVar('base.500'),
|
||||
label: '',
|
||||
};
|
||||
|
||||
export const makeEdgeSelector = (
|
||||
@@ -19,25 +20,34 @@ export const makeEdgeSelector = (
|
||||
targetHandleId: string | null | undefined,
|
||||
selected?: boolean
|
||||
) =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
|
||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||
createMemoizedSelector(
|
||||
selectNodesSlice,
|
||||
(nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
|
||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||
|
||||
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||
|
||||
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
|
||||
if (!sourceNode || !sourceHandleId) {
|
||||
return defaultReturnValue;
|
||||
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
|
||||
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
|
||||
return defaultReturnValue;
|
||||
}
|
||||
|
||||
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
|
||||
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
||||
|
||||
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||
|
||||
const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
|
||||
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
|
||||
|
||||
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
|
||||
|
||||
return {
|
||||
isSelected,
|
||||
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
|
||||
stroke,
|
||||
label,
|
||||
};
|
||||
}
|
||||
|
||||
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
|
||||
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
||||
|
||||
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||
|
||||
return {
|
||||
isSelected,
|
||||
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
|
||||
stroke,
|
||||
};
|
||||
});
|
||||
);
|
||||
|
||||
@@ -16,7 +16,7 @@ const props: ChakraProps = { w: 'unset' };
|
||||
|
||||
const InvocationNodeFooter = ({ nodeId }: Props) => {
|
||||
const hasImageOutput = useHasImageOutput(nodeId);
|
||||
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
|
||||
const isCacheEnabled = useFeatureStatus('invocationCache');
|
||||
return (
|
||||
<Flex
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
|
||||