Compare commits

...

123 Commits

Author SHA1 Message Date
Ryan Dick
f6045682c0 Fix bug with partial offload of model buffers. 2024-12-10 22:19:17 +00:00
Ryan Dick
84a75ddb72 Fix bug in ModelCache that was causing it to offload more models from VRAM than necessary. 2024-12-10 20:38:37 +00:00
Ryan Dick
a9fb1c82a0 Fix handling of torch.nn.Module buffers in CachedModelWithPartialLoad. 2024-12-10 19:38:04 +00:00
Ryan Dick
cc7391e630 Enable LoRAPatcher.apply_smart_lora_patches(...) throughout the stack. 2024-12-10 17:27:33 +00:00
Ryan Dick
62d595f695 (minor) Rename num_layers -> num_loras in unit tests. 2024-12-10 16:41:52 +00:00
Ryan Dick
5e2080266e Add test_apply_smart_lora_patches_to_partially_loaded_model(...). 2024-12-10 16:38:48 +00:00
Ryan Dick
ed7bb7ea3d Add LoRAPatcher.smart_apply_lora_patches() 2024-12-10 16:26:34 +00:00
Ryan Dick
62407f7c6b Refactor LoRAPatcher slightly in preparation for a 'smart' patcher. 2024-12-10 15:36:36 +00:00
Ryan Dick
80128e1e14 Fix LoRAPatcher.apply_lora_wrapper_patches(...) 2024-12-10 03:10:23 +00:00
Ryan Dick
4c84d39e7d Finish consolidating LoRA sidecar wrapper implementations. 2024-12-10 02:54:32 +00:00
Ryan Dick
0c4a368555 Begin to consolidate the LoRA sidecar and LoRA layer wrapper implementations. 2024-12-10 01:16:01 +00:00
Ryan Dick
55dc762a91 Fix bias handling in LoRAModuleWrapper and add unit test that checks that all LoRA patching methods produce the same outputs. 2024-12-09 16:59:37 +00:00
Ryan Dick
d825d3856e Add LoRA wrapper patching to LoRAPatcher. 2024-12-09 16:35:23 +00:00
Ryan Dick
d94733f55a Add LoRA wrapper layer. 2024-12-09 15:17:50 +00:00
Ryan Dick
2144d21f80 Maintain a read-only CPU state dict copy in CachedModelWithPartialLoad. 2024-12-06 21:49:24 +00:00
Ryan Dick
958efa19d7 Memoize frequently accessed values in CachedModelWithPartialLoad. 2024-12-06 20:39:05 +00:00
Ryan Dick
11af57def3 More ModelCache logging improvements. 2024-12-06 18:38:36 +00:00
Ryan Dick
8b70a5b9bd Cleanup of ModelCache and added a bunch of debug logging. 2024-12-06 17:39:16 +00:00
Ryan Dick
5d9fdcd78d Fix a couple of bugs to get basic vanilla partial model load working with the model cache. 2024-12-06 00:50:58 +00:00
Ryan Dick
c7b84cf012 WIP - first pass at overhauling ModelCache to work with partial loads. 2024-12-05 23:03:40 +00:00
Ryan Dick
8e409e3436 Delete experimental torch device autocasting solutions and clean up TorchFunctionAutocastDeviceContext. 2024-12-05 19:36:44 +00:00
Ryan Dick
987393853c Create CachedModelOnlyFullLoad class. 2024-12-05 18:43:50 +00:00
Ryan Dick
91c5af1b95 Move CachedModelWithPartialLoad into the main model_cache/ directory. 2024-12-05 18:21:26 +00:00
Ryan Dick
5c67dd507a Get rid of ModelLocker. It was an unnecessary layer of indirection. 2024-12-05 16:59:40 +00:00
Ryan Dick
2ff928ec17 Move lock(...) and unlock(...) logic from ModelLocker to the ModelCache and make a bunch of ModelCache properties/methods private. 2024-12-05 16:11:40 +00:00
Ryan Dick
4327bbe77e Pull get_model_cache_key(...) out of ModelCache. The ModelCache should not be concerned with implementation details like the submodel_type. 2024-12-04 22:53:57 +00:00
Ryan Dick
ad1c0d37ef Rename model_cache_default.py -> model_cache.py. 2024-12-04 22:45:30 +00:00
Ryan Dick
9708d87946 Remove ModelCacheBase. 2024-12-04 22:05:34 +00:00
Ryan Dick
3ad44f7850 Move CacheStats to its own file. 2024-12-04 21:56:50 +00:00
Ryan Dick
9a482981b2 Move CacheRecord out to its own file. 2024-12-04 21:53:19 +00:00
Ryan Dick
6b02362b12 Rip out ModelLockerBase. 2024-12-04 21:47:11 +00:00
Ryan Dick
8fec4ec91c Tidy up CachedModel and improve unit test coverage. 2024-12-04 20:28:31 +00:00
Ryan Dick
693e421970 Alternative implementation with torch.nn.Linear module streaming. 2024-12-03 22:32:15 +00:00
Ryan Dick
dc14104bc8 Add TorchFunctionAutocastContext 2024-12-03 19:26:46 +00:00
Ryan Dick
f286a1d1f3 Remove debug logs. 2024-12-03 18:04:55 +00:00
Ryan Dick
9dc86b2b71 Add basic CachedModel class with features for partial load/unload. 2024-12-03 17:12:22 +00:00
Ryan Dick
2cab689b79 Naive TorchAutocastContext. 2024-12-03 14:55:43 +00:00
psychedelicious
f8c7adddd0 feat(ui): add vietnamese to language picker
Closes #7384
2024-12-02 08:12:14 -05:00
psychedelicious
17da1d92e9 fix(ui): remove "adding to" text on Invoke tooltip on Workflows/Upscaling tabs
The "adding to" text indicates if images are going to the gallery or staging area. This info is relevant only to the canvas tab, but was displayed on Upscaling and Workflows tabs. Removed it from those tabs.
2024-12-02 08:08:16 -05:00
psychedelicious
1cc57a4854 chore(ui): lint 2024-12-02 07:59:12 -05:00
psychedelicious
3993fae331 fix(ui): unable to invoke w/ empty inpaint mask or raster layer
Removed the empty state checks for these layer types - it's always OK to invoke when they are empty.
2024-12-02 07:59:12 -05:00
psychedelicious
1446526d55 tidy(ui): translation keys for canvas layer warnings 2024-12-02 07:59:12 -05:00
psychedelicious
62c024e725 feat(ui): add gallery image ctx menu items to create ref image from image
Appears these actions disappeared at some point. Restoring them.
2024-12-02 07:52:58 -05:00
psychedelicious
1e92bb4e94 fix(ui): ref image defaults to prev ref image's image selection
A redux selector is used to get the "default" IP Adapter. The selector uses the model list query result to select an IP Adapter model to be preset by default.

The selector is memoized, so if we mutate the returned default IP Adapter state, it mutates the result of the selector for all consumers.

For example, the `image` property of the default IP Adapter selector result is `null`. When we set the `image` property of the selector result while creating an IP Adapter, this does not trigger the selector to recompute its result. We end up setting the image for the selector result directly, and all other consumers now have that same image set.

Solution - we need to clone the selector result everywhere it is used. This was missed in a few spots, causing the issue.
2024-12-02 07:48:39 -05:00
psychedelicious
db6398fdf6 feat(ui): less confusing empty state for rg ref images
It was easy to misunderstand the empty state for a regional guidance reference image. There was no label, so it seemed like it was the whole region that was empty.

This small change adds the "Reference Image" heading to the empty state, so it's clear that the empty state messaging refers to this reference image, not the whole regional guidance layer.
2024-12-02 07:46:10 -05:00
Riccardo Giovanetti
ebd73a2ac2 translationBot(ui): update translation (Italian)
Currently translated at 98.7% (1622 of 1643 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-12-02 02:13:51 -08:00
Hosted Weblate
8ee95cab00 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-12-02 02:13:51 -08:00
Linos
d1184201a8 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1643 of 1643 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1638 of 1638 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2024-12-02 02:13:51 -08:00
Nik Nikovsky
5887891654 translationBot(ui): update translation (Polish)
Currently translated at 4.9% (81 of 1638 strings)

Co-authored-by: Nik Nikovsky <zejdzztegomaila@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/pl/
Translation: InvokeAI/Web UI
2024-12-02 02:13:51 -08:00
Riku
765ca4e004 translationBot(ui): update translation (German)
Currently translated at 69.7% (1142 of 1638 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-12-02 02:13:51 -08:00
Riku
159b00a490 fix(app): adjust session queue api type 2024-12-01 20:06:05 -08:00
Riku
3fbf6f2d2a chore(ui): update typegen schema 2024-12-01 19:56:09 -08:00
Riku
931fca7cd1 fix(ui): call cancel instead of clear queue 2024-12-01 19:53:12 -08:00
Riku
db84a3a5d4 refactor(ui): move clear queue hook to separate file 2024-12-01 19:42:25 -08:00
psychedelicious
ca8313e805 feat(ui): add new layer from image menu items for staging area
The layers are disabled when created so as to not interfere with the canvas state.
2024-12-01 19:37:49 -08:00
psychedelicious
df849035ee feat(ui): allow setting isEnabled, isLocked and name in createNewCanvasEntityFromImage util 2024-12-01 19:37:49 -08:00
psychedelicious
8d97fe69ca feat(ui): use imageDTOToFile in staging area save to gallery button 2024-12-01 19:37:49 -08:00
psychedelicious
9044e53a9b feat(ui): add imageDTOToFile util 2024-12-01 19:37:49 -08:00
Jonathan
6012b0f912 Update flux_text_encoder.py
Updated version number for FLUX Text Encoding.
2024-11-30 08:29:21 -05:00
Jonathan
bb0ed5dc8a Update flux_denoise.py
Updated node version for FLUX Denoise.
2024-11-30 08:29:21 -05:00
Ryan Dick
021552fd81 Avoid unnecessary dtype conversions with rope encodings. 2024-11-29 12:32:50 -05:00
Ryan Dick
be73dbba92 Use view() instead of rearrange() for better performance. 2024-11-29 12:32:50 -05:00
Ryan Dick
db9c0cad7c Replace custom RMSNorm implementation with torch.nn.functional.rms_norm(...) for improved speed. 2024-11-29 12:32:50 -05:00
Ryan Dick
54b7f9a063 FLUX Regional Prompting (#7388)
## Summary

This PR adds support for regional prompting with FLUX.

### Example 1
Global prompt: `An architecture rendering of the reception area of a
corporate office with modern decor.`
<img width="1386" alt="image"
src="https://github.com/user-attachments/assets/c8169bdb-49a9-44bc-bd9e-58d98e09094b">

![image](https://github.com/user-attachments/assets/4a426be9-9d7a-4527-b27c-2d2514ee73fe)

## QA Instructions

- [x] Test that there is no slowdown in the base case with a single
global prompt.
- [x] Test image fully covered by regional masks.
- [x] Test image covered by region masks with small gaps.
- [x] Test region masks with large unmasked ‘background’ regions
- [x] Test region masks with significant overlap
- [x] Test multiple global prompts.
- [x] Test no global prompt.
- [x] Test regional negative prompts (It runs... but results are not
great. Needs more tuning to be useful.)
- Test compatibility with:
    - [x] ControlNet
    - [x] LoRA
    - [x] IP-Adapter

## Remaining TODO

- [x] Disable the following UI features for FLUX prompt regions:
negative prompts, reference images, auto-negative.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2024-11-29 08:56:42 -05:00
psychedelicious
7d488a5352 feat(ui): add delete button to regional ref image empty state 2024-11-29 15:51:24 +10:00
psychedelicious
4d7667f63d fix(ui): add missing translations 2024-11-29 15:43:49 +10:00
psychedelicious
08704ee8ec feat(ui): use canvas layer validators in control/ip adapter graph builders 2024-11-29 15:32:48 +10:00
psychedelicious
5910892c33 Merge remote-tracking branch 'origin/main' into ryan/flux-regional-prompting 2024-11-29 15:19:39 +10:00
psychedelicious
46a09d9e90 feat(ui): format warnings tooltip 2024-11-29 13:32:51 +10:00
psychedelicious
df0c7d73f3 feat(ui): use regional guidance validation utils in graph builders 2024-11-29 13:26:09 +10:00
psychedelicious
3905c97e32 feat(ui): return translation keys from validation utils instead of translated strings 2024-11-29 13:25:09 +10:00
psychedelicious
0be796a808 feat(ui): use layer validation utils in invoke readiness utils 2024-11-29 13:14:26 +10:00
psychedelicious
7dd33b0f39 feat(ui): add indicator to canvas layer headers, displaying validation warnings
If there are any issues with the layer, the icon is displayed. If the layer is disabled, the icon is greyed out but still visible.
2024-11-29 13:13:47 +10:00
psychedelicious
484aaf1595 feat(ui): add canvas layer validation utils
These helpers consolidate layer validation checks. For example, checking that the layer has content drawn, is compatible with the selected main model, has valid reference images, etc.
2024-11-29 13:12:32 +10:00
psychedelicious
c276b60af9 tidy(ui): use object for addRegions graph builder util arg 2024-11-29 08:49:41 +10:00
Ryan Dick
5d8dd6e26e Fix FLUX regional negative prompts. 2024-11-28 18:49:29 +00:00
Emmanuel Ferdman
5bca68d873 docs: update code of conduct reference
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2024-11-27 17:38:33 -08:00
Ryan Dick
64364e7911 Short-circuit if there are no region masks in FLUX and don't apply attention masking. 2024-11-27 22:40:10 +00:00
Ryan Dick
6565cea039 Comment unused _prepare_unrestricted_attn_mask(...) for future reference. 2024-11-27 22:16:44 +00:00
Ryan Dick
3ebd8d6c07 Delete outdated TODO comment. 2024-11-27 22:13:25 +00:00
Ryan Dick
e970185161 Tweak flux regional prompting attention scheme based on latest experimentation. 2024-11-27 22:13:07 +00:00
Ryan Dick
fa5653cdf7 Remove unused 'denoise' param to addRegions(). 2024-11-27 17:08:42 +00:00
Ryan Dick
9a7b000995 Update frontend to support regional prompts with FLUX in the canvas. 2024-11-27 17:04:43 +00:00
Ryan Dick
3a27242838 Bump transformers. The main motivation for this bump is to ingest a fix for DepthAnything postprocessing artifacts. 2024-11-27 07:46:16 -08:00
Ryan Dick
8cfb032051 Add utility ImagePanelLayoutInvocation for working with In-Context LoRA workflows. 2024-11-26 20:58:31 -08:00
Ryan Dick
06a9d4e2b2 Use a Textarea component for the FluxTextEncoderInvocation prompt field. 2024-11-26 20:58:31 -08:00
Brandon Rising
ed46acee79 fix: Fail scan on InvalidMagicError in picklescan, update default for read_checkpoint_meta to scan unless explicitly told not to 2024-11-26 16:17:12 -05:00
Ryan Dick
b54463d294 Allow regional prompting background regions to attend to themselves and to the entire txt embedding. 2024-11-26 17:57:31 +00:00
Ryan Dick
faee79dc95 Distinguish between restricted and unrestricted attn masks in FLUX regional prompting. 2024-11-26 16:55:52 +00:00
Mary Hipp
965cd76e33 lint fix 2024-11-26 11:25:53 -05:00
Mary Hipp
e5e8cbf34c shorten reference image mode descriptions; 2024-11-26 11:25:53 -05:00
Mary Hipp
3412a52594 (ui): updates various informational tooltips, adds descriptons to IP adapter method options 2024-11-26 11:25:53 -05:00
Ryan Dick
e01f66b026 Apply regional attention masks in the single stream blocks in addition to the double stream blocks. 2024-11-25 22:40:08 +00:00
Ryan Dick
53abdde242 Update Flux RegionalPromptingExtension to prepare both a mask with restricted image self-attention and a mask with unrestricted image self attention. 2024-11-25 22:04:23 +00:00
Ryan Dick
94c088300f Be smarter about selecting the global CLIP embedding for FLUX regional prompting. 2024-11-25 20:15:04 +00:00
Ryan Dick
3741a6f5e0 Fix device handling for regional masks and apply the attention mask in the FLUX double stream block. 2024-11-25 16:02:03 +00:00
Kent Keirsey
059336258f Create SECURITY.md 2024-11-25 04:10:03 -08:00
Ryan Dick
2c23b8414c Use a single global CLIP embedding for FLUX regional guidance. 2024-11-22 23:01:43 +00:00
Mary Hipp
271cc52c80 fix(ui): use token for download if its in store 2024-11-22 12:08:05 -05:00
Ryan Dick
20356c0746 Fixup the logic for preparing FLUX regional prompt attention masks. 2024-11-21 22:46:25 +00:00
psychedelicious
e44458609f chore: bump version to v5.4.3rc1 2024-11-21 10:32:43 -08:00
psychedelicious
69d86a7696 feat(ui): address feedback 2024-11-21 09:54:35 -08:00
Hippalectryon
56db1a9292 Use proxyrect and setEntityPosition to sync transformer position 2024-11-21 09:54:35 -08:00
Hippalectryon
cf50e5eeee Make sure the canvas is focused 2024-11-21 09:54:35 -08:00
Hippalectryon
c9c07968d2 lint 2024-11-21 09:54:35 -08:00
Hippalectryon
97d0757176 use $isInteractable instead of $isDisabled 2024-11-21 09:54:35 -08:00
Hippalectryon
0f51b677a9 refactor 2024-11-21 09:54:35 -08:00
Hippalectryon
56ca94c3a9 Don't move if the layer is disabled
Lint
2024-11-21 09:54:35 -08:00
Hippalectryon
28d169f859 Allow moving layers using the keyboard 2024-11-21 09:54:35 -08:00
psychedelicious
92f71d99ee tweak(ui): use X icon for rg ref image delete button 2024-11-21 08:50:39 -08:00
psychedelicious
0764c02b1d tweak(ui): code style 2024-11-21 08:50:39 -08:00
psychedelicious
081c7569fe feat(ui): add global ref image empty state 2024-11-21 08:50:39 -08:00
psychedelicious
20f6532ee8 feat(ui): add empty state for regional guidance ref image 2024-11-21 08:50:39 -08:00
Mary Hipp
b9e8910478 feat(ui): add actions for video modal clicks 2024-11-21 11:15:55 -05:00
Mary Hipp
ded8391e3c use nanostore for schema parsed instead 2024-11-20 20:13:31 -05:00
Mary Hipp
e9dd2c396a limit to one hook 2024-11-20 20:13:31 -05:00
Mary Hipp
0d86de0cb5 fix(ui): make sure schema has loaded before trying to load any workflows 2024-11-20 20:13:31 -05:00
Ryan Dick
bad1149504 WIP - add rough logic for preparing the FLUX regional prompting attention mask. 2024-11-20 22:29:36 +00:00
Ryan Dick
fda7aaa7ca Pass RegionalPromptingExtension down to the CustomDoubleStreamBlockProcessor in FLUX. 2024-11-20 19:48:04 +00:00
Ryan Dick
85c616fa34 WIP - Pass prompt masks to FLUX model during denoising. 2024-11-20 18:51:43 +00:00
psychedelicious
549f4e9794 feat(ui): set default infill method to lama 2024-11-20 11:19:17 -05:00
psychedelicious
ef8ededd2f fix(ui): disable width and height output on image batch output
There's a technical challenge with outputting these values directly. `ImageField` does not store them, so the batch's `ImageField` collection does not have width and height for each image.

In order to set up the batch and pass along width and height for each image, we'd need to make a network request for each image when the user clicks Invoke. It would often be cached, but this will eventually create a scaling issue and poor user experience.

As a very simple workaround, users can output the batch image output into an `Image Primitive` node to access the width and height.

This change is implemented by adding some simple special handling when parsing the output fields for the `image_batch` node.

I'll keep this situation in mind when extending the batching system to other field types.
2024-11-20 11:16:54 -05:00
Mary Hipp
1948ffe106 make sure Soft Edge Detection has preprocessor applied 2024-11-20 08:46:02 -05:00
122 changed files with 4183 additions and 1614 deletions

14
SECURITY.md Normal file
View File

@@ -0,0 +1,14 @@
# Security Policy
## Supported Versions
Only the latest version of Invoke will receive security updates.
We do not currently maintain multiple versions of the application with updates.
## Reporting a Vulnerability
To report a vulnerability, contact the Invoke team directly at security@invoke.ai
At this time, we do not maintain a formal bug bounty program.
You can also share identified security issues with our team on huntr.com

View File

@@ -1364,7 +1364,6 @@ the in-memory loaded model:
|----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
### get_model_by_key(key, [submodel]) -> LoadedModel

View File

@@ -38,7 +38,7 @@ This project is a combined effort of dedicated people from across the world. [C
## Code of Conduct
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/docs/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
By making a contribution to this project, you certify that:

View File

@@ -37,7 +37,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch

View File

@@ -110,7 +110,7 @@ async def cancel_by_batch_ids(
@session_queue_router.put(
"/{queue_id}/cancel_by_destination",
operation_id="cancel_by_destination",
responses={200: {"model": CancelByBatchIDsResult}},
responses={200: {"model": CancelByDestinationResult}},
)
async def cancel_by_destination(
queue_id: str = Path(description="The queue id to perform this operation on"),

View File

@@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=text_encoder,
patches=_lora_loader(),
prefix="lora_te_",
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -179,10 +180,11 @@ class SDXLPromptInvocationBase:
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
text_encoder,
patches=_lora_loader(),
prefix=lora_prefix,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.

View File

@@ -1003,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
dtype=unet.dtype,
cached_weights=cached_weights,
),
):

View File

@@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)
class SD3ConditioningField(BaseModel):

View File

@@ -30,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
@@ -42,6 +43,7 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
@@ -56,7 +58,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.2.1",
version="3.2.2",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -87,10 +89,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: FluxConditioningField = InputField(
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_text_conditioning: FluxConditioningField | None = InputField(
negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
default=None,
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
@@ -139,36 +141,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _load_text_conditioning(
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
return t5_embeddings, clip_embeddings
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
# Load the conditioning data.
pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
context, self.positive_text_conditioning.conditioning_name, inference_dtype
)
neg_t5_embeddings: torch.Tensor | None = None
neg_clip_embeddings: torch.Tensor | None = None
if self.negative_text_conditioning is not None:
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype
)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
@@ -183,15 +161,45 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
seed=self.seed,
)
b, _c, latent_h, latent_w = noise.shape
packed_h = latent_h // 2
packed_w = latent_w // 2
# Load the conditioning data.
pos_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.positive_text_conditioning,
packed_height=packed_h,
packed_width=packed_w,
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
neg_text_conditionings: list[FluxTextConditioning] | None = None
if self.negative_text_conditioning is not None:
neg_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.negative_text_conditioning,
packed_height=packed_h,
packed_width=packed_w,
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
pos_text_conditionings, img_seq_len=packed_h * packed_w
)
neg_regional_prompting_extension = (
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
if neg_text_conditionings
else None
)
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=image_seq_len,
image_seq_len=packed_h * packed_w,
shift=not is_schnell,
)
@@ -228,28 +236,17 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
pos_txt_ids = torch.zeros(
pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
neg_txt_ids: torch.Tensor | None = None
if neg_t5_embeddings is not None:
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
neg_txt_ids = torch.zeros(
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
noise = pack(noise)
x = pack(x)
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
assert image_seq_len == x.shape[1]
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
# packed_w correctly.
assert packed_h * packed_w == x.shape[1]
# Prepare inpaint extension.
inpaint_extension: InpaintExtension | None = None
@@ -299,10 +296,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
)
)
@@ -314,7 +312,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
LoRAPatcher.apply_lora_sidecar_patches(
LoRAPatcher.apply_lora_wrapper_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
@@ -338,12 +336,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
model=transformer,
img=x,
img_ids=img_ids,
txt=pos_t5_embeddings,
txt_ids=pos_txt_ids,
vec=pos_clip_embeddings,
neg_txt=neg_t5_embeddings,
neg_txt_ids=neg_txt_ids,
neg_vec=neg_clip_embeddings,
pos_regional_prompting_extension=pos_regional_prompting_extension,
neg_regional_prompting_extension=neg_regional_prompting_extension,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
@@ -357,6 +351,43 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
x = unpack(x.float(), self.height, self.width)
return x
def _load_text_conditioning(
self,
context: InvocationContext,
cond_field: FluxConditioningField | list[FluxConditioningField],
packed_height: int,
packed_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[FluxTextConditioning]:
"""Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
# Normalize to a list of FluxConditioningFields.
cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
text_conditionings: list[FluxTextConditioning] = []
for cond_field in cond_list:
# Load the text embeddings.
cond_data = context.conditioning.load(cond_field.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
# Load the mask, if provided.
mask: Optional[torch.Tensor] = None
if cond_field.mask is not None:
mask = context.tensors.load(cond_field.mask.tensor_name)
mask = mask.to(device=device)
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
mask, packed_height, packed_width, dtype, device
)
text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
return text_conditionings
@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int

View File

@@ -1,11 +1,18 @@
from contextlib import ExitStack
from typing import Iterator, Literal, Tuple
from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
TensorField,
UIComponent,
)
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -15,6 +22,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -22,7 +30,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.1.0",
version="1.1.1",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
@@ -41,7 +49,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
prompt: str = InputField(description="Text prompt to encode.")
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
@@ -54,7 +65,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
)
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
return FluxConditioningOutput(
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
@@ -99,10 +112,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
)
)

View File

@@ -0,0 +1,59 @@
from pydantic import ValidationInfo, field_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("image_panel_coordinate_output")
class ImagePanelCoordinateOutput(BaseInvocationOutput):
x_left: int = OutputField(description="The left x-coordinate of the panel.")
y_top: int = OutputField(description="The top y-coordinate of the panel.")
width: int = OutputField(description="The width of the panel.")
height: int = OutputField(description="The height of the panel.")
@invocation(
"image_panel_layout",
title="Image Panel Layout",
tags=["image", "panel", "layout"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class ImagePanelLayoutInvocation(BaseInvocation):
"""Get the coordinates of a single panel in a grid. (If the full image shape cannot be divided evenly into panels,
then the grid may not cover the entire image.)
"""
width: int = InputField(description="The width of the entire grid.")
height: int = InputField(description="The height of the entire grid.")
num_cols: int = InputField(ge=1, default=1, description="The number of columns in the grid.")
num_rows: int = InputField(ge=1, default=1, description="The number of rows in the grid.")
panel_col_idx: int = InputField(ge=0, default=0, description="The column index of the panel to be processed.")
panel_row_idx: int = InputField(ge=0, default=0, description="The row index of the panel to be processed.")
@field_validator("panel_col_idx")
def validate_panel_col_idx(cls, v: int, info: ValidationInfo) -> int:
if v < 0 or v >= info.data["num_cols"]:
raise ValueError(f"panel_col_idx must be between 0 and {info.data['num_cols'] - 1}")
return v
@field_validator("panel_row_idx")
def validate_panel_row_idx(cls, v: int, info: ValidationInfo) -> int:
if v < 0 or v >= info.data["num_rows"]:
raise ValueError(f"panel_row_idx must be between 0 and {info.data['num_rows'] - 1}")
return v
def invoke(self, context: InvocationContext) -> ImagePanelCoordinateOutput:
x_left = self.panel_col_idx * (self.width // self.num_cols)
y_top = self.panel_row_idx * (self.height // self.num_rows)
width = self.width // self.num_cols
height = self.height // self.num_rows
return ImagePanelCoordinateOutput(x_left=x_left, y_top=y_top, width=width, height=height)

View File

@@ -21,6 +21,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@@ -150,10 +151,11 @@ class Sd3TextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
LoRAPatcher.apply_smart_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
)
)

View File

@@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
with (
ExitStack() as exit_stack,
unet_info as unet,
LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
LoRAPatcher.apply_smart_lora_patches(
model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)

View File

@@ -20,7 +20,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_common import (
NodeExecutionStatsSummary,
)
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load.model_cache import CacheStats
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
# Size of 1GB in bytes.
GB = 2**30

View File

@@ -7,7 +7,7 @@ from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
class ModelLoadServiceBase(ABC):
@@ -24,7 +24,7 @@ class ModelLoadServiceBase(ABC):
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the RAM cache used by this loader."""
@abstractmethod

View File

@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load import (
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
def __init__(
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""
@@ -45,7 +45,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._invoker = invoker
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the RAM cache used by this loader."""
return self._ram_cache
@@ -78,15 +78,14 @@ class ModelLoadService(ModelLoadServiceBase):
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
except IndexError:
pass
def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
result = torch_load(checkpoint, map_location="cpu")
return result
@@ -109,5 +108,5 @@ class ModelLoadService(ModelLoadServiceBase):
)
assert loader is not None
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
self._ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)

View File

@@ -16,7 +16,8 @@ from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBas
from invokeai.app.services.model_load.model_load_default import ModelLoadService
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

View File

@@ -1,9 +1,10 @@
import einops
import torch
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.math import attention
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock
class CustomDoubleStreamBlockProcessor:
@@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor:
@staticmethod
def _double_stream_block_forward(
block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
block: DoubleStreamBlock,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
attn_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
values.
@@ -40,7 +46,7 @@ class CustomDoubleStreamBlockProcessor:
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
@@ -63,11 +69,15 @@ class CustomDoubleStreamBlockProcessor:
vec: torch.Tensor,
pe: torch.Tensor,
ip_adapter_extensions: list[XLabsIPAdapterExtension],
regional_prompting_extension: RegionalPromptingExtension,
) -> tuple[torch.Tensor, torch.Tensor]:
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
- IP-Adapter support
"""
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
block, img, txt, vec, pe, attn_mask=attn_mask
)
# Apply IP-Adapter conditioning.
for ip_adapter_extension in ip_adapter_extensions:
@@ -81,3 +91,48 @@ class CustomDoubleStreamBlockProcessor:
)
return img, txt
class CustomSingleStreamBlockProcessor:
"""A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking,
etc.)
"""
@staticmethod
def _single_stream_block_forward(
block: SingleStreamBlock,
x: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""This function is a direct copy of SingleStreamBlock.forward()."""
mod, _ = block.modulation(vec)
x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)
q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
q, k = block.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
return x + mod.gate * output
@staticmethod
def custom_single_block_forward(
timestep_index: int,
total_num_timesteps: int,
block_index: int,
block: SingleStreamBlock,
img: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
regional_prompting_extension: RegionalPromptingExtension,
) -> torch.Tensor:
"""A custom implementation of SingleStreamBlock.forward() with additional features:
- Masking
"""
attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)

View File

@@ -7,6 +7,7 @@ from tqdm import tqdm
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.model import Flux
@@ -18,14 +19,8 @@ def denoise(
# model input
img: torch.Tensor,
img_ids: torch.Tensor,
# positive text conditioning
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
# negative text conditioning
neg_txt: torch.Tensor | None,
neg_txt_ids: torch.Tensor | None,
neg_vec: torch.Tensor | None,
pos_regional_prompting_extension: RegionalPromptingExtension,
neg_regional_prompting_extension: RegionalPromptingExtension | None,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
@@ -61,9 +56,9 @@ def denoise(
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
)
@@ -78,9 +73,9 @@ def denoise(
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
@@ -88,6 +83,7 @@ def denoise(
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=pos_ip_adapter_extensions,
regional_prompting_extension=pos_regional_prompting_extension,
)
step_cfg_scale = cfg_scale[step_index]
@@ -97,15 +93,15 @@ def denoise(
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
# on systems with sufficient VRAM.
if neg_txt is None or neg_txt_ids is None or neg_vec is None:
if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
neg_pred = model(
img=img,
img_ids=img_ids,
txt=neg_txt,
txt_ids=neg_txt_ids,
y=neg_vec,
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
@@ -113,6 +109,7 @@ def denoise(
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
regional_prompting_extension=neg_regional_prompting_extension,
)
pred = neg_pred + step_cfg_scale * (pred - neg_pred)

View File

@@ -0,0 +1,276 @@
from typing import Optional
import torch
import torchvision
from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask
class RegionalPromptingExtension:
"""A class for managing regional prompting with FLUX.
This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
"""
def __init__(
self,
regional_text_conditioning: FluxRegionalTextConditioning,
restricted_attn_mask: torch.Tensor | None = None,
):
self.regional_text_conditioning = regional_text_conditioning
self.restricted_attn_mask = restricted_attn_mask
def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
order = [self.restricted_attn_mask, None]
return order[block_index % len(order)]
def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
order = [self.restricted_attn_mask, None]
return order[block_index % len(order)]
@classmethod
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
"""Create a RegionalPromptingExtension from a list of text conditionings.
Args:
text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
"""
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
regional_text_conditioning, img_seq_len
)
return cls(
regional_text_conditioning=regional_text_conditioning,
restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
)
# Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
#
# @classmethod
# def _prepare_unrestricted_attn_mask(
# cls,
# regional_text_conditioning: FluxRegionalTextConditioning,
# img_seq_len: int,
# ) -> torch.Tensor:
# """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
# - img self-attention is not masked.
# - img regions attend to both txt within their own region and to global prompts.
# """
# device = TorchDevice.choose_torch_device()
# # Infer txt_seq_len from the t5_embeddings tensor.
# txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
# # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# # Concatenation happens in the following order: [txt_seq, img_seq].
# # There are 4 portions of the attention mask to consider as we prepare it:
# # 1. txt attends to itself
# # 2. txt attends to corresponding regional img
# # 3. regional img attends to corresponding txt
# # 4. regional img attends to itself
# # Initialize empty attention mask.
# regional_attention_mask = torch.zeros(
# (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
# )
# for image_mask, t5_embedding_range in zip(
# regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
# ):
# # 1. txt attends to itself
# regional_attention_mask[
# t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
# ] = 1.0
# # 2. txt attends to corresponding regional img
# # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
# fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
# regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value
# # 3. regional img attends to corresponding txt
# # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
# fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
# regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
# # 4. regional img attends to itself
# # Allow unrestricted img self attention.
# regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
# # Convert attention mask to boolean.
# regional_attention_mask = regional_attention_mask > 0.5
# return regional_attention_mask
@classmethod
def _prepare_restricted_attn_mask(
cls,
regional_text_conditioning: FluxRegionalTextConditioning,
img_seq_len: int,
) -> torch.Tensor | None:
"""Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
- img self-attention is only allowed within regions.
- img regions only attend to txt within their own region, not to global prompts.
"""
# Identify background region. I.e. the region that is not covered by any region masks.
background_region_mask: None | torch.Tensor = None
for image_mask in regional_text_conditioning.image_masks:
if image_mask is not None:
if background_region_mask is None:
background_region_mask = torch.ones_like(image_mask)
background_region_mask *= 1 - image_mask
if background_region_mask is None:
# There are no region masks, short-circuit and return None.
# TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this would
# is a rare use case and would make the logic here significantly more complicated.
return None
device = TorchDevice.choose_torch_device()
# Infer txt_seq_len from the t5_embeddings tensor.
txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
# In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
# Concatenation happens in the following order: [txt_seq, img_seq].
# There are 4 portions of the attention mask to consider as we prepare it:
# 1. txt attends to itself
# 2. txt attends to corresponding regional img
# 3. regional img attends to corresponding txt
# 4. regional img attends to itself
# Initialize empty attention mask.
regional_attention_mask = torch.zeros(
(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
)
for image_mask, t5_embedding_range in zip(
regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
):
# 1. txt attends to itself
regional_attention_mask[
t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
] = 1.0
if image_mask is not None:
# 2. txt attends to corresponding regional img
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
image_mask.view(1, img_seq_len)
)
# 3. regional img attends to corresponding txt
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
image_mask.view(img_seq_len, 1)
)
# 4. regional img attends to itself
image_mask = image_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
else:
# We don't allow attention between non-background image regions and global prompts. This helps to ensure
# that regions focus on their local prompts. We do, however, allow attention between background regions
# and global prompts. If we didn't do this, then the background regions would not attend to any txt
# embeddings, which we found experimentally to cause artifacts.
# 2. global txt attends to background region
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
background_region_mask.view(1, img_seq_len)
)
# 3. background region attends to global txt
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
background_region_mask.view(img_seq_len, 1)
)
# Allow background regions to attend to themselves.
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
# Convert attention mask to boolean.
regional_attention_mask = regional_attention_mask > 0.5
return regional_attention_mask
@classmethod
def _concat_regional_text_conditioning(
cls,
text_conditionings: list[FluxTextConditioning],
) -> FluxRegionalTextConditioning:
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
concat_t5_embeddings: list[torch.Tensor] = []
concat_t5_embedding_ranges: list[Range] = []
image_masks: list[torch.Tensor | None] = []
# Choose global CLIP embedding.
# Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
# the first prompt's CLIP embedding.
global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
for text_conditioning in text_conditionings:
if text_conditioning.mask is None:
global_clip_embedding = text_conditioning.clip_embeddings
break
cur_t5_embedding_len = 0
for text_conditioning in text_conditionings:
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
concat_t5_embedding_ranges.append(
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
)
image_masks.append(text_conditioning.mask)
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
# Initialize the txt_ids tensor.
pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
t5_txt_ids = torch.zeros(
pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
)
return FluxRegionalTextConditioning(
t5_embeddings=t5_embeddings,
clip_embeddings=global_clip_embedding,
t5_txt_ids=t5_txt_ids,
image_masks=image_masks,
t5_embedding_ranges=concat_t5_embedding_ranges,
)
@staticmethod
def preprocess_regional_prompt_mask(
mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
packed_height and packed_width are the target height and width of the mask in the 'packed' latent space.
Returns:
torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width).
"""
if mask is None:
return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)
mask = to_standard_float_mask(mask, out_dtype=dtype)
tf = torchvision.transforms.Resize(
(packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
# Flatten the height and width dimensions into a single image_seq_len dimension.
return resized_mask.flatten(start_dim=2)

View File

@@ -5,10 +5,10 @@ from einops import rearrange
from torch import Tensor
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = rearrange(x, "B H L D -> B L (H D)")
return x
@@ -24,12 +24,12 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
return out.to(dtype=pos.dtype, device=pos.device)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_ = xq.view(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.view(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
return xq_out.view(*xq.shape), xk_out.view(*xk.shape)

View File

@@ -5,7 +5,11 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
from invokeai.backend.flux.custom_block_processor import (
CustomDoubleStreamBlockProcessor,
CustomSingleStreamBlockProcessor,
)
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
@@ -95,6 +99,7 @@ class Flux(nn.Module):
controlnet_double_block_residuals: list[Tensor] | None,
controlnet_single_block_residuals: list[Tensor] | None,
ip_adapter_extensions: list[XLabsIPAdapterExtension],
regional_prompting_extension: RegionalPromptingExtension,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -117,7 +122,6 @@ class Flux(nn.Module):
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
for block_index, block in enumerate(self.double_blocks):
assert isinstance(block, DoubleStreamBlock)
img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
timestep_index=timestep_index,
total_num_timesteps=total_num_timesteps,
@@ -128,6 +132,7 @@ class Flux(nn.Module):
vec=vec,
pe=pe,
ip_adapter_extensions=ip_adapter_extensions,
regional_prompting_extension=regional_prompting_extension,
)
if controlnet_double_block_residuals is not None:
@@ -140,7 +145,17 @@ class Flux(nn.Module):
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
for block_index, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
assert isinstance(block, SingleStreamBlock)
img = CustomSingleStreamBlockProcessor.custom_single_block_forward(
timestep_index=timestep_index,
total_num_timesteps=total_num_timesteps,
block_index=block_index,
block=block,
img=img,
vec=vec,
pe=pe,
regional_prompting_extension=regional_prompting_extension,
)
if controlnet_single_block_residuals is not None:
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]

View File

@@ -66,10 +66,7 @@ class RMSNorm(torch.nn.Module):
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
return torch.nn.functional.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
class QKNorm(torch.nn.Module):

View File

@@ -0,0 +1,36 @@
from dataclasses import dataclass
import torch
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
@dataclass
class FluxTextConditioning:
t5_embeddings: torch.Tensor
clip_embeddings: torch.Tensor
# If mask is None, the prompt is a global prompt.
mask: torch.Tensor | None
@dataclass
class FluxRegionalTextConditioning:
# Concatenated text embeddings.
# Shape: (1, concatenated_txt_seq_len, 4096)
t5_embeddings: torch.Tensor
# Shape: (1, concatenated_txt_seq_len, 3)
t5_txt_ids: torch.Tensor
# Global CLIP embeddings.
# Shape: (1, 768)
clip_embeddings: torch.Tensor
# A binary mask indicating the regions of the image that the prompt should be applied to. If None, the prompt is a
# global prompt.
# image_masks[i] is the mask for the ith prompt.
# image_masks[i] has shape (1, image_seq_len) and dtype torch.bool.
image_masks: list[torch.Tensor | None]
# List of ranges that represent the embedding ranges for each mask.
# t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
t5_embedding_ranges: list[Range]

View File

@@ -0,0 +1,133 @@
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRASidecarWrapper(torch.nn.Module):
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
super().__init__()
self._orig_module = orig_module
self._lora_layers = lora_layers
self._lora_weights = lora_weights
@property
def orig_module(self) -> torch.nn.Module:
return self._orig_module
def add_lora_layer(self, lora_layer: AnyLoRALayer, lora_weight: float):
self._lora_layers.append(lora_layer)
self._lora_weights.append(lora_weight)
@torch.no_grad()
def _get_lora_patched_parameters(
self, orig_params: dict[str, torch.Tensor], lora_layers: list[AnyLoRALayer], lora_weights: list[float]
) -> dict[str, torch.Tensor]:
params: dict[str, torch.Tensor] = {}
for lora_layer, lora_weight in zip(lora_layers, lora_weights, strict=True):
layer_params = lora_layer.get_parameters(self._orig_module)
for param_name, param_weight in layer_params.items():
if orig_params[param_name].shape != param_weight.shape:
param_weight = param_weight.reshape(orig_params[param_name].shape)
if param_name not in params:
params[param_name] = param_weight * (lora_layer.scale() * lora_weight)
else:
params[param_name] += param_weight * (lora_layer.scale() * lora_weight)
return params
class LoRALinearWrapper(LoRASidecarWrapper):
def _lora_linear_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear LoRALayer."""
x = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x = torch.nn.functional.linear(x, lora_layer.mid)
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
x *= lora_weight * lora_layer.scale()
return x
def _concatenated_lora_forward(
self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Split the LoRA layers into those that have optimized implementations and those that don't.
optimized_layer_types = (LoRALayer, ConcatenatedLoRALayer)
optimized_layers = [
(layer, weight)
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
if isinstance(layer, optimized_layer_types)
]
non_optimized_layers = [
(layer, weight)
for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
if not isinstance(layer, optimized_layer_types)
]
# First, calculate the residual for LoRA layers for which there is an optimized implementation.
residual = None
for lora_layer, lora_weight in optimized_layers:
if isinstance(lora_layer, LoRALayer):
added_residual = self._lora_linear_forward(input, lora_layer, lora_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
added_residual = self._concatenated_lora_forward(input, lora_layer, lora_weight)
else:
raise ValueError(f"Unsupported LoRA layer type: {type(lora_layer)}")
if residual is None:
residual = added_residual
else:
residual += added_residual
# Next, calculate the residuals for the LoRA layers for which there is no optimized implementation.
if non_optimized_layers:
unoptimized_layers, unoptimized_weights = zip(*non_optimized_layers, strict=True)
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=unoptimized_layers,
lora_weights=unoptimized_weights,
)
added_residual = torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
if residual is None:
residual = added_residual
else:
residual += added_residual
return self.orig_module(input) + residual
class LoRAConv1dWrapper(LoRASidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=self._lora_layers,
lora_weights=self._lora_weights,
)
return self.orig_module(input) + torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))
class LoRAConv2dWrapper(LoRASidecarWrapper):
def forward(self, input: torch.Tensor) -> torch.Tensor:
params = self._get_lora_patched_parameters(
orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
lora_layers=self._lora_layers,
lora_weights=self._lora_weights,
)
return self.orig_module(input) + torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))

View File

@@ -4,19 +4,126 @@ from typing import Dict, Iterable, Optional, Tuple
import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
ConcatenatedLoRALinearSidecarLayer,
from invokeai.backend.lora.lora_layer_wrappers import (
LoRAConv1dWrapper,
LoRAConv2dWrapper,
LoRALinearWrapper,
LoRASidecarWrapper,
)
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class LoRAPatcher:
@staticmethod
@torch.no_grad()
@contextmanager
def apply_smart_lora_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply 'smart' LoRA patching that chooses whether to use direct patching or a sidecar wrapper for each module."""
# original_weights are stored for unpatching layers that are directly patched.
original_weights = OriginalWeightsStorage(cached_weights)
# original_modules are stored for unpatching layers that are wrapped in a LoRASidecarWrapper.
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoRAPatcher._apply_smart_lora_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
original_modules=original_modules,
dtype=dtype,
)
yield
finally:
# Restore directly patched layers.
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
# Restore LoRASidecarWrapper modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
@torch.no_grad()
def _apply_smart_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
patching or a sidecar wrapper for each module.
"""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LoRAPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Decide whether to use direct patching or a sidecar wrapper.
# Direct patching is preferred, because it results in better runtime speed.
# Reasons to use sidecar patching:
# - The module is already wrapped in a LoRASidecarWrapper.
# - The module is quantized.
# - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
# CPU, since this would double the RAM usage)
# NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
# and that the caller will use the 'apply_lora_wrapper_patches' method if the layer is quantized.
# TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
# forcing full patching even on the CPU?
if isinstance(module, LoRASidecarWrapper) or LoRAPatcher._is_any_part_of_layer_on_cpu(module):
LoRAPatcher._apply_lora_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
else:
LoRAPatcher._apply_lora_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_weights=original_weights,
)
@staticmethod
def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
return any(p.device.type == "cpu" for p in layer.parameters())
@staticmethod
@torch.no_grad()
@contextmanager
@@ -40,7 +147,7 @@ class LoRAPatcher:
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
LoRAPatcher.apply_lora_patch(
LoRAPatcher._apply_lora_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -56,7 +163,7 @@ class LoRAPatcher:
@staticmethod
@torch.no_grad()
def apply_lora_patch(
def _apply_lora_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
@@ -91,48 +198,67 @@ class LoRAPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
LoRAPatcher._apply_lora_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_weights=original_weights,
)
layer_scale = layer.scale()
@staticmethod
@torch.no_grad()
def _apply_lora_layer_patch(
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
patch: AnyLoRALayer,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module_to_patch.weight.device
dtype = module_to_patch.weight.dtype
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
layer_scale = patch.scale()
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
patch.to(device=device)
patch.to(dtype=torch.float32)
# Save original weight
original_weights.save(param_key, module_param)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in patch.get_parameters(module_to_patch).items():
param_key = module_to_patch_key + "." + param_name
module_param = module_to_patch.get_parameter(param_name)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
# Save original weight
original_weights.save(param_key, module_param)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
layer.to(device=TorchDevice.CPU_DEVICE)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
patch.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
@torch.no_grad()
@contextmanager
def apply_lora_sidecar_patches(
def apply_lora_wrapper_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
prefix: str,
dtype: torch.dtype,
):
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
quantization format.
"""Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
runtime overhead compared to normal LoRA patching, but they enable:
- LoRA layers to be applied to quantized models
- LoRA layers to be applied to CPU layers without needing to store a full copy of the original weights (i.e.
avoid doubling the memory requirements).
Args:
model (torch.nn.Module): The model to patch.
@@ -140,14 +266,11 @@ class LoRAPatcher:
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
different from their compute dtype.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LoRAPatcher._apply_lora_sidecar_patch(
LoRAPatcher._apply_lora_wrapper_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -165,7 +288,7 @@ class LoRAPatcher:
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
@staticmethod
def _apply_lora_sidecar_patch(
def _apply_lora_wrapper_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch_weight: float,
@@ -173,7 +296,7 @@ class LoRAPatcher:
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA sidecar patch to a model."""
"""Apply a single LoRA wrapper patch to a model."""
if patch_weight == 0:
return
@@ -194,28 +317,47 @@ class LoRAPatcher:
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
# Initialize the LoRA sidecar layer.
lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
LoRAPatcher._apply_lora_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
# Replace the original module with a LoRASidecarModule if it has not already been done.
if module_key in original_modules:
# The module has already been patched with a LoRASidecarModule. Append to it.
assert isinstance(module, LoRASidecarModule)
lora_sidecar_module = module
else:
# The module has not yet been patched with a LoRASidecarModule. Create one.
lora_sidecar_module = LoRASidecarModule(module, [])
original_modules[module_key] = module
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
module_parent = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
@staticmethod
@torch.no_grad()
def _apply_lora_layer_wrapper_patch(
model: torch.nn.Module,
module_to_patch: torch.nn.Module,
module_to_patch_key: str,
patch: AnyLoRALayer,
patch_weight: float,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA wrapper patch to a model."""
# Move the LoRA sidecar layer to the same device/dtype as the orig module.
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
# Replace the original module with a LoRASidecarWrapper if it has not already been done.
if not isinstance(module_to_patch, LoRASidecarWrapper):
lora_wrapper_layer = LoRAPatcher._initialize_lora_wrapper_layer(module_to_patch)
original_modules[module_to_patch_key] = module_to_patch
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_to_patch_key)
module_parent = model.get_submodule(module_parent_key)
LoRAPatcher._set_submodule(module_parent, module_name, lora_wrapper_layer)
orig_module = module_to_patch
else:
assert module_to_patch_key in original_modules
lora_wrapper_layer = module_to_patch
orig_module = module_to_patch.orig_module
# Add the LoRA sidecar layer to the LoRASidecarModule.
lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
# Move the LoRA layer to the same device/dtype as the orig module.
patch.to(device=orig_module.weight.device, dtype=dtype)
# Add the LoRA wrapper layer to the LoRASidecarWrapper.
lora_wrapper_layer.add_lora_layer(patch, patch_weight)
@staticmethod
def _split_parent_key(module_key: str) -> tuple[str, str]:
@@ -236,17 +378,13 @@ class LoRAPatcher:
raise ValueError(f"Invalid module key: {module_key}")
@staticmethod
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
# TODO(ryand): Add support for more original layer types and LoRA layer types.
if isinstance(orig_layer, torch.nn.Linear) or (
isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
):
if isinstance(lora_layer, LoRALayer):
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
elif isinstance(lora_layer, ConcatenatedLoRALayer):
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
else:
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
def _initialize_lora_wrapper_layer(orig_layer: torch.nn.Module):
if isinstance(orig_layer, torch.nn.Linear):
return LoRALinearWrapper(orig_layer, [], [])
elif isinstance(orig_layer, torch.nn.Conv1d):
return LoRAConv1dWrapper(orig_layer, [], [])
elif isinstance(orig_layer, torch.nn.Conv2d):
return LoRAConv2dWrapper(orig_layer, [], [])
else:
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")

View File

@@ -1,34 +0,0 @@
import torch
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
def __init__(
self,
concatenated_lora_layer: ConcatenatedLoRALayer,
weight: float,
):
super().__init__()
self._concatenated_lora_layer = concatenated_lora_layer
self._weight = weight
def forward(self, input: torch.Tensor) -> torch.Tensor:
x_chunks: list[torch.Tensor] = []
for lora_layer in self._concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= self._weight * lora_layer.scale()
x_chunks.append(x_chunk)
# TODO(ryand): Generalize to support concat_axis != 0.
assert self._concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._concatenated_lora_layer.to(device=device, dtype=dtype)
return self

View File

@@ -1,27 +0,0 @@
import torch
from invokeai.backend.lora.layers.lora_layer import LoRALayer
class LoRALinearSidecarLayer(torch.nn.Module):
def __init__(
self,
lora_layer: LoRALayer,
weight: float,
):
super().__init__()
self._lora_layer = lora_layer
self._weight = weight
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.linear(x, self._lora_layer.down)
if self._lora_layer.mid is not None:
x = torch.nn.functional.linear(x, self._lora_layer.mid)
x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
x *= self._weight * self._lora_layer.scale()
return x
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._lora_layer.to(device=device, dtype=dtype)
return self

View File

@@ -1,24 +0,0 @@
import torch
class LoRASidecarModule(torch.nn.Module):
"""A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
super().__init__()
self.orig_module = orig_module
self._lora_layers = lora_layers
def add_lora_layer(self, lora_layer: torch.nn.Module):
self._lora_layers.append(lora_layer)
def forward(self, input: torch.Tensor) -> torch.Tensor:
x = self.orig_module(input)
for lora_layer in self._lora_layers:
x += lora_layer(input)
return x
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self._orig_module.to(device=device, dtype=dtype)
for lora_layer in self._lora_layers:
lora_layer.to(device=device, dtype=dtype)

View File

@@ -8,7 +8,7 @@ from pathlib import Path
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
# This registers the subclasses that implement loaders of specific model types

View File

@@ -5,7 +5,6 @@ Base class for model loading in InvokeAI.
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Tuple
@@ -18,19 +17,17 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
SubModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
@dataclass
class LoadedModelWithoutConfig:
"""
Context manager object that mediates transfer from RAM<->VRAM.
"""Context manager object that mediates transfer from RAM<->VRAM.
This is a context manager object that has two distinct APIs:
1. Older API (deprecated):
Use the LoadedModel object directly as a context manager.
It will move the model into VRAM (on CUDA devices), and
Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
return the model in a form suitable for passing to torch.
Example:
```
@@ -40,13 +37,9 @@ class LoadedModelWithoutConfig:
```
2. Newer API (recommended):
Call the LoadedModel's `model_on_device()` method in a
context. It returns a tuple consisting of a copy of
the model's state dict in CPU RAM followed by a copy
of the model in VRAM. The state dict is provided to allow
LoRAs and other model patchers to return the model to
its unpatched state without expensive copy and restore
operations.
Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
other model patchers to return the model to its unpatched state without expensive copy and restore operations.
Example:
```
@@ -55,43 +48,42 @@ class LoadedModelWithoutConfig:
image = vae.decode(latents)[0]
```
The state_dict should be treated as a read-only object and
never modified. Also be aware that some loadable models do
not have a state_dict, in which case this value will be None.
The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
do not have a state_dict, in which case this value will be None.
"""
_locker: ModelLockerBase
def __init__(self, cache_record: CacheRecord, cache: ModelCache):
self._cache_record = cache_record
self._cache = cache
def __enter__(self) -> AnyModel:
"""Context entry."""
self._locker.lock()
self._cache.lock(self._cache_record.key)
return self.model
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""
self._locker.unlock()
self._cache.unlock(self._cache_record.key)
@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
locked_model = self._locker.lock()
self._cache.lock(self._cache_record.key)
try:
state_dict = self._locker.get_state_dict()
yield (state_dict, locked_model)
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
finally:
self._locker.unlock()
self._cache.unlock(self._cache_record.key)
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self._locker.model
return self._cache_record.cached_model.model
@dataclass
class LoadedModel(LoadedModelWithoutConfig):
"""Context manager object that mediates transfer from RAM<->VRAM."""
config: Optional[AnyModelConfig] = None
def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
super().__init__(cache_record=cache_record, cache=cache)
self.config = config
# TODO(MM2):
@@ -110,7 +102,7 @@ class ModelLoaderBase(ABC):
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
):
"""Initialize the loader."""
pass
@@ -138,6 +130,6 @@ class ModelLoaderBase(ABC):
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the ram cache associated with this loader."""
pass

View File

@@ -14,7 +14,8 @@ from invokeai.backend.model_manager import (
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
@@ -28,7 +29,7 @@ class ModelLoader(ModelLoaderBase):
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
):
"""Initialize the loader."""
self._app_config = app_config
@@ -54,11 +55,11 @@ class ModelLoader(ModelLoaderBase):
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
with skip_torch_weight_init():
locker = self._load_and_cache(model_config, submodel_type)
return LoadedModel(config=model_config, _locker=locker)
cache_record = self._load_and_cache(model_config, submodel_type)
return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
def ram_cache(self) -> ModelCache:
"""Return the ram cache associated with this loader."""
return self._ram_cache
@@ -66,10 +67,10 @@ class ModelLoader(ModelLoaderBase):
model_base = self._app_config.models_path
return (model_base / config.path).resolve()
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
try:
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
except IndexError:
pass
@@ -78,16 +79,11 @@ class ModelLoader(ModelLoaderBase):
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(
config.key,
submodel_type=submodel_type,
get_model_cache_key(config.key, submodel_type),
model=loaded_model,
)
return self._ram_cache.get(
key=config.key,
submodel_type=submodel_type,
stats_name=stats_name,
)
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
def get_size_fs(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None

View File

@@ -1,6 +0,0 @@
"""Init file for ModelCache."""
from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
from .model_cache_default import ModelCache # noqa F401
_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]

View File

@@ -0,0 +1,31 @@
from dataclasses import dataclass
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
@dataclass
class CacheRecord:
"""A class that represents a model in the model cache."""
# Cache key.
key: str
# Model in memory.
cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
# If locks > 0, the model is actively being used, so we should do our best to keep it on the compute device.
_locks: int = 0
def lock(self) -> None:
self._locks += 1
def unlock(self) -> None:
self._locks -= 1
assert self._locks >= 0
@property
def is_locked(self) -> bool:
return self._locks > 0

View File

@@ -0,0 +1,15 @@
from dataclasses import dataclass, field
from typing import Dict
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)

View File

@@ -0,0 +1,81 @@
from typing import Any
import torch
class CachedModelOnlyFullLoad:
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
"""Initialize a CachedModelOnlyFullLoad.
Args:
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
compute_device (torch.device): The compute device to move the model to.
total_bytes (int): The total size (in bytes) of all the weights in the model.
"""
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
self._model = model
self._compute_device = compute_device
self._total_bytes = total_bytes
self._is_in_vram = False
@property
def model(self) -> torch.nn.Module:
return self._model
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better and implement it.
return None
def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return self._total_bytes
def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
if self._is_in_vram:
return self._total_bytes
else:
return 0
def is_in_vram(self) -> bool:
"""Return true if the model is currently in VRAM."""
return self._is_in_vram
def full_load_to_vram(self) -> int:
"""Load all weights into VRAM (if supported by the model).
Returns:
The number of bytes loaded into VRAM.
"""
if self._is_in_vram:
# Already in VRAM.
return 0
if not hasattr(self._model, "to"):
# Model doesn't support moving to a device.
return 0
self._model.to(self._compute_device)
self._is_in_vram = True
return self._total_bytes
def full_unload_from_vram(self) -> int:
"""Unload all weights from VRAM.
Returns:
The number of bytes unloaded from VRAM.
"""
if not self._is_in_vram:
# Already in RAM.
return 0
self._model.to("cpu")
self._is_in_vram = False
return self._total_bytes

View File

@@ -0,0 +1,150 @@
import itertools
import torch
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
add_autocast_to_module_forward,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
def set_nested_attr(obj: object, attr: str, value: object):
"""A helper function that extends setattr() to support nested attributes.
Example:
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
"""
attrs = attr.split(".")
for attr in attrs[:-1]:
obj = getattr(obj, attr)
setattr(obj, attrs[-1], value)
class CachedModelWithPartialLoad:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
self._model = model
self._compute_device = compute_device
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
# Monkey-patch the model to add autocasting to the model's forward method.
add_autocast_to_module_forward(model, compute_device)
self._total_bytes = sum(
calc_tensor_size(p) for p in itertools.chain(self._model.parameters(), self._model.buffers())
)
self._cur_vram_bytes: int | None = None
@property
def model(self) -> torch.nn.Module:
return self._model
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better.
return self._cpu_state_dict
def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return self._total_bytes
def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
if self._cur_vram_bytes is None:
self._cur_vram_bytes = sum(
calc_tensor_size(p)
for p in itertools.chain(self._model.parameters(), self._model.buffers())
if p.device.type == self._compute_device.type
)
return self._cur_vram_bytes
def full_load_to_vram(self) -> int:
"""Load all weights into VRAM."""
return self.partial_load_to_vram(self.total_bytes())
def full_unload_from_vram(self) -> int:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())
@torch.no_grad()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
Returns:
The number of bytes loaded into VRAM.
"""
vram_bytes_loaded = 0
for key, param in itertools.chain(self._model.named_parameters(), self._model.named_buffers()):
# Skip parameters that are already on the compute device.
if param.device.type == self._compute_device.type:
continue
# Check the size of the parameter.
param_size = calc_tensor_size(param)
if vram_bytes_loaded + param_size > vram_bytes_to_load:
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
# worth continuing to search for a smaller parameter that would fit?
continue
# Copy the parameter to the compute device.
# We use the 'overwrite' strategy from torch.nn.Module._apply().
# TODO(ryand): For some edge cases (e.g. quantized models?), we may need to support other strategies (e.g.
# swap).
if isinstance(param, torch.nn.Parameter):
assert param.is_leaf
out_param = torch.nn.Parameter(
param.to(self._compute_device, copy=True), requires_grad=param.requires_grad
)
set_nested_attr(self._model, key, out_param)
# We did not port the param.grad handling from torch.nn.Module._apply(), because we do not expect to be
# handling gradients. We assert that this assumption is true.
assert param.grad is None
else:
# Handle buffers.
set_nested_attr(self._model, key, param.to(self._compute_device, copy=True))
vram_bytes_loaded += param_size
if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
return vram_bytes_loaded
@torch.no_grad()
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
Returns:
The number of bytes unloaded from VRAM.
"""
vram_bytes_freed = 0
for key, param in itertools.chain(self._model.named_parameters(), self._model.named_buffers()):
if vram_bytes_freed >= vram_bytes_to_free:
break
if param.device.type != self._compute_device.type:
continue
if isinstance(param, torch.nn.Parameter):
# Create a new parameter, but inject the existing CPU tensor into it.
out_param = torch.nn.Parameter(self._cpu_state_dict[key], requires_grad=param.requires_grad)
set_nested_attr(self._model, key, out_param)
else:
# Handle buffers.
set_nested_attr(self._model, key, self._cpu_state_dict[key])
vram_bytes_freed += calc_tensor_size(param)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed
return vram_bytes_freed

View File

@@ -0,0 +1,538 @@
import gc
from logging import Logger
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter
# Size of a GB in bytes.
GB = 2**30
# Size of a MB in bytes.
MB = 2**20
# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
"""Get the cache key for a model based on the optional submodel type."""
if submodel_type:
return f"{model_key}:{submodel_type.value}"
else:
return model_key
class ModelCache:
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
The model cache is based on the following assumptions:
- storage_device_mem_size > execution_device_mem_size
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
the execution_device.
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
configuration.
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
the context, and unload outside the context.
Example usage:
```
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
do_something_on_gpu(SD1)
```
"""
def __init__(
self,
max_cache_size: float,
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
lazy_offloading: bool = True,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
:param logger: InvokeAILogger to use (otherwise creates one)
"""
# allow lazy offloading only when vram cache enabled
# TODO(ryand): Think about what lazy_offloading should mean in the new model cache.
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._logger = PrefixedLoggerAdapter(
logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE"
)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
self._cached_models: Dict[str, CacheRecord] = {}
self._cache_stack: List[str] = []
@property
def max_cache_size(self) -> float:
"""Return the cap on cache size."""
return self._max_cache_size
@max_cache_size.setter
def max_cache_size(self, value: float) -> None:
"""Set the cap on cache size."""
self._max_cache_size = value
@property
def max_vram_cache_size(self) -> float:
"""Return the cap on vram cache size."""
return self._max_vram_cache_size
@max_vram_cache_size.setter
def max_vram_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
self._max_vram_cache_size = value
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
return self._stats
@stats.setter
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collecting cache statistics."""
self._stats = stats
def put(self, key: str, model: AnyModel) -> None:
"""Add a model to the cache."""
if key in self._cached_models:
self._logger.debug(
f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary."
)
return
size = calc_model_size_by_data(self._logger, model)
self.make_room(size)
# Wrap model.
if isinstance(model, torch.nn.Module):
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
else:
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
# running_on_cpu = self._execution_device == torch.device("cpu")
# state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
self._logger.debug(
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
)
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
"""Retrieve a model from the cache.
:param key: Model key
:param stats_name: A human-readable id for the model for the purposes of stats reporting.
Raises IndexError if the model is not in the cache.
"""
if key in self._cached_models:
if self.stats:
self.stats.hits += 1
else:
if self.stats:
self.stats.misses += 1
self._logger.debug(f"Cache miss: {key}")
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
)
# this moves the entry to the top (right end) of the stack
self._cache_stack = [k for k in self._cache_stack if k != key]
self._cache_stack.append(key)
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
return cache_entry
def lock(self, key: str) -> None:
"""Lock a model for use and move it into VRAM."""
cache_entry = self._cached_models[key]
cache_entry.lock()
self._logger.debug(f"Locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
try:
self._load_locked_model(cache_entry)
self._logger.debug(
f"Finished locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
)
except torch.cuda.OutOfMemoryError:
self._logger.warning("Insufficient GPU memory to load model. Aborting")
cache_entry.unlock()
raise
except Exception:
cache_entry.unlock()
raise
self._log_cache_state()
def unlock(self, key: str) -> None:
"""Unlock a model."""
cache_entry = self._cached_models[key]
cache_entry.unlock()
self._logger.debug(f"Unlocked model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
def _load_locked_model(self, cache_entry: CacheRecord) -> None:
"""Helper function for self.lock(). Loads a locked model into VRAM."""
vram_available = self._get_vram_available()
# Calculate model_vram_needed, the amount of additional VRAM that will be used if we fully load the model into
# VRAM.
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
model_total_bytes = cache_entry.cached_model.total_bytes()
model_vram_needed = model_total_bytes - model_cur_vram_bytes
# The amount of VRAM that must be freed to make room for model_vram_needed.
vram_bytes_to_free = max(0, model_vram_needed - vram_available)
self._logger.debug(
f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
)
# Make room for the model in VRAM.
# 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
# 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
# possible.
vram_bytes_freed = self._offload_unlocked_models(vram_bytes_to_free)
self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
# Check the updated vram_available after offloading.
vram_available = self._get_vram_available()
self._logger.debug(
f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
)
# Move as much of the model as possible into VRAM.
model_bytes_loaded = 0
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
model_bytes_loaded = cache_entry.cached_model.partial_load_to_vram(vram_available)
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
# Partial load is not supported, so we have not choice but to try and fit it all into VRAM.
model_bytes_loaded = cache_entry.cached_model.full_load_to_vram()
else:
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
vram_available = self._get_vram_available()
self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ")
self._logger.debug(
f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
)
def _get_vram_available(self) -> int:
"""Get the amount of VRAM available in the cache."""
return int(self._max_vram_cache_size * GB) - self._get_vram_in_use()
def _get_vram_in_use(self) -> int:
"""Get the amount of VRAM currently in use."""
return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
def _get_ram_available(self) -> int:
"""Get the amount of RAM available in the cache."""
return int(self._max_cache_size * GB) - self._get_ram_in_use()
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str:
"""Helper function for preparing a VRAM state log string."""
model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
return (
f"model_total={model_total_bytes/MB:.0f} MB, "
+ f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
+ f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
+ f"vram_available={(vram_available/MB):.0f} MB, "
)
def _offload_unlocked_models(self, vram_bytes_to_free: int) -> int:
"""Offload models from the execution_device until vram_bytes_to_free bytes are freed, or all models are
offloaded. Of course, locked models are not offloaded.
Returns:
int: The number of bytes freed.
"""
self._logger.debug(f"Offloading unlocked models with goal of freeing {vram_bytes_to_free/MB:.2f}MB of VRAM.")
vram_bytes_freed = 0
# TODO(ryand): Give more thought to the offloading policy used here.
cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
for cache_entry in cache_entries_increasing_size:
if vram_bytes_freed >= vram_bytes_to_free:
break
if cache_entry.is_locked:
continue
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
cache_entry_bytes_freed = cache_entry.cached_model.partial_unload_from_vram(
vram_bytes_to_free - vram_bytes_freed
)
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
cache_entry_bytes_freed = cache_entry.cached_model.full_unload_from_vram()
else:
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
if cache_entry_bytes_freed > 0:
self._logger.debug(
f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB."
)
vram_bytes_freed += cache_entry_bytes_freed
TorchDevice.empty_cache()
return vram_bytes_freed
# def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
# """Move model into the indicated device.
# :param cache_entry: The CacheRecord for the model
# :param target_device: The torch.device to move the model into
# May raise a torch.cuda.OutOfMemoryError
# """
# self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
# source_device = cache_entry.device
# # Note: We compare device types only so that 'cuda' == 'cuda:0'.
# # This would need to be revised to support multi-GPU.
# if torch.device(source_device).type == torch.device(target_device).type:
# return
# # Some models don't have a `to` method, in which case they run in RAM/CPU.
# if not hasattr(cache_entry.model, "to"):
# return
# # This roundabout method for moving the model around is done to avoid
# # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# # When moving to VRAM, we copy (not move) each element of the state dict from
# # RAM to a new state dict in VRAM, and then inject it into the model.
# # This operation is slightly faster than running `to()` on the whole model.
# #
# # When the model needs to be removed from VRAM we simply delete the copy
# # of the state dict in VRAM, and reinject the state dict that is cached
# # in RAM into the model. So this operation is very fast.
# start_model_to_time = time.time()
# snapshot_before = self._capture_memory_snapshot()
# try:
# if cache_entry.state_dict is not None:
# assert hasattr(cache_entry.model, "load_state_dict")
# if target_device == self._storage_device:
# cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
# else:
# new_dict: Dict[str, torch.Tensor] = {}
# for k, v in cache_entry.state_dict.items():
# new_dict[k] = v.to(target_device, copy=True)
# cache_entry.model.load_state_dict(new_dict, assign=True)
# cache_entry.model.to(target_device)
# cache_entry.device = target_device
# except Exception as e: # blow away cache entry
# self._delete_cache_entry(cache_entry)
# raise e
# snapshot_after = self._capture_memory_snapshot()
# end_model_to_time = time.time()
# self._logger.debug(
# f"Moved model '{cache_entry.key}' from {source_device} to"
# f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
# f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
# )
# if (
# snapshot_before is not None
# and snapshot_after is not None
# and snapshot_before.vram is not None
# and snapshot_after.vram is not None
# ):
# vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# # If the estimated model size does not match the change in VRAM, log a warning.
# if not math.isclose(
# vram_change,
# cache_entry.size,
# rel_tol=0.1,
# abs_tol=10 * MB,
# ):
# self._logger.debug(
# f"Moving model '{cache_entry.key}' from {source_device} to"
# f" {target_device} caused an unexpected change in VRAM usage. The model's"
# " estimated size may be incorrect. Estimated model size:"
# f" {(cache_entry.size/GB):.3f} GB.\n"
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
# )
def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
ram_size_bytes = self._max_cache_size * GB
ram_in_use_bytes = self._get_ram_in_use()
ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
ram_available_bytes = self._get_ram_available()
ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
vram_size_bytes = self._max_vram_cache_size * GB
vram_in_use_bytes = self._get_vram_in_use()
vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
vram_available_bytes = self._get_vram_available()
vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
log = f"{title}\n"
log_format = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"
log += log_format.format(
f"Storage Device ({self._storage_device.type})",
ram_size_bytes / MB,
ram_in_use_bytes / MB,
ram_in_use_bytes_percent,
ram_available_bytes / MB,
ram_available_bytes_percent,
)
log += log_format.format(
f"Compute Device ({self._execution_device.type})",
vram_size_bytes / MB,
vram_in_use_bytes / MB,
vram_in_use_bytes_percent,
vram_available_bytes / MB,
vram_available_bytes_percent,
)
if torch.cuda.is_available():
log += " {:<30} {} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
log += " {:<30} {}\n".format("Total models:", len(self._cached_models))
if include_entry_details and len(self._cached_models) > 0:
log += " Models:\n"
log_format = (
" {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n"
)
for cache_record in self._cached_models.values():
total_bytes = cache_record.cached_model.total_bytes()
cur_vram_bytes = cache_record.cached_model.cur_vram_bytes()
cur_vram_bytes_percent = cur_vram_bytes / total_bytes if total_bytes > 0 else 0
cur_ram_bytes = total_bytes - cur_vram_bytes
cur_ram_bytes_percent = cur_ram_bytes / total_bytes if total_bytes > 0 else 0
log += log_format.format(
f"{cache_record.key} ({cache_record.cached_model.model.__class__.__name__}):",
total_bytes / MB,
cur_vram_bytes / MB,
cur_vram_bytes_percent,
cur_ram_bytes / MB,
cur_ram_bytes_percent,
cache_record.is_locked,
)
self._logger.debug(log)
def make_room(self, bytes_needed: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
self._logger.debug(f"Making room for {bytes_needed/MB:.2f}MB of RAM.")
self._log_cache_state(title="Before dropping models:")
ram_bytes_available = self._get_ram_available()
ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available)
ram_bytes_freed = 0
pos = 0
models_cleared = 0
while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
if not cache_entry.is_locked:
ram_bytes_freed += cache_entry.cached_model.total_bytes()
self._logger.debug(
f"Dropping {model_key} from RAM cache to free {(cache_entry.cached_model.total_bytes()/MB):.2f}MB."
)
self._delete_cache_entry(cache_entry)
del cache_entry
models_cleared += 1
else:
pos += 1
if models_cleared > 0:
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
# is high even if no garbage gets collected.)
#
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
# - If models had to be cleared, it's a signal that we are close to our memory limit.
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
# collected.
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
if self.stats:
self.stats.cleared = models_cleared
gc.collect()
TorchDevice.empty_cache()
self._logger.debug(f"Dropped {models_cleared} models to free {ram_bytes_freed/MB:.2f}MB of RAM.")
self._log_cache_state(title="After dropping models:")
def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]

View File

@@ -1,221 +0,0 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Generic, Optional, TypeVar
import torch
from invokeai.backend.model_manager.config import AnyModel, SubModelType
class ModelLockerBase(ABC):
"""Base class for the model locker used by the loader."""
@abstractmethod
def lock(self) -> AnyModel:
"""Lock the contained model and move it into VRAM."""
pass
@abstractmethod
def unlock(self) -> None:
"""Unlock the contained model, and remove it from VRAM."""
pass
@abstractmethod
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
pass
@property
@abstractmethod
def model(self) -> AnyModel:
"""Return the model."""
pass
T = TypeVar("T")
@dataclass
class CacheRecord(Generic[T]):
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
key: str
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0
def lock(self) -> None:
"""Lock this record."""
self._locks += 1
def unlock(self) -> None:
"""Unlock this record."""
self._locks -= 1
assert self._locks >= 0
@property
def locked(self) -> bool:
"""Return true if record is locked."""
return self._locks > 0
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelCacheBase(ABC, Generic[T]):
"""Virtual base class for RAM model cache."""
@property
@abstractmethod
def storage_device(self) -> torch.device:
"""Return the storage device (e.g. "CPU" for RAM)."""
pass
@property
@abstractmethod
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
pass
@property
@abstractmethod
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@property
@abstractmethod
def max_cache_size(self) -> float:
"""Return the maximum size the RAM cache can grow to."""
pass
@max_cache_size.setter
@abstractmethod
def max_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
@property
@abstractmethod
def max_vram_cache_size(self) -> float:
"""Return the maximum size the VRAM cache can grow to."""
pass
@max_vram_cache_size.setter
@abstractmethod
def max_vram_cache_size(self, value: float) -> float:
"""Set the maximum size the VRAM cache can grow to."""
pass
@abstractmethod
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload from VRAM any models not actively in use."""
pass
@abstractmethod
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device."""
pass
@property
@abstractmethod
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
pass
@stats.setter
@abstractmethod
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collectin cache statistics."""
pass
@property
@abstractmethod
def logger(self) -> Logger:
"""Return the logger used by the cache."""
pass
@abstractmethod
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
pass
@abstractmethod
def put(
self,
key: str,
model: T,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
pass
@abstractmethod
def get(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
stats_name: Optional[str] = None,
) -> ModelLockerBase:
"""
Retrieve model using key and optional submodel_type.
:param key: Opaque model key
:param submodel_type: Type of the submodel to fetch
:param stats_name: A human-readable id for the model for the purposes of
stats reporting.
This may raise an IndexError if the model is not in the cache.
"""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
pass
@abstractmethod
def print_cuda_stats(self) -> None:
"""Log debugging information on CUDA usage."""
pass

View File

@@ -1,426 +0,0 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
""" """
import gc
import math
import time
from contextlib import suppress
from logging import Logger
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
CacheRecord,
CacheStats,
ModelCacheBase,
ModelLockerBase,
)
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
# Size of a GB in bytes.
GB = 2**30
# Size of a MB in bytes.
MB = 2**20
class ModelCache(ModelCacheBase[AnyModel]):
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
The model cache is based on the following assumptions:
- storage_device_mem_size > execution_device_mem_size
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
the execution_device.
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
configuration.
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
the context, and unload outside the context.
Example usage:
```
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
do_something_on_gpu(SD1)
```
"""
def __init__(
self,
max_cache_size: float,
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
lazy_offloading: bool = True,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
:param logger: InvokeAILogger to use (otherwise creates one)
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
return self._logger
@property
def lazy_offloading(self) -> bool:
"""Return true if the cache is configured to lazily offload models in VRAM."""
return self._lazy_offloading
@property
def storage_device(self) -> torch.device:
"""Return the storage device (e.g. "CPU" for RAM)."""
return self._storage_device
@property
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
return self._execution_device
@property
def max_cache_size(self) -> float:
"""Return the cap on cache size."""
return self._max_cache_size
@max_cache_size.setter
def max_cache_size(self, value: float) -> None:
"""Set the cap on cache size."""
self._max_cache_size = value
@property
def max_vram_cache_size(self) -> float:
"""Return the cap on vram cache size."""
return self._max_vram_cache_size
@max_vram_cache_size.setter
def max_vram_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
self._max_vram_cache_size = value
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
return self._stats
@stats.setter
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collectin cache statistics."""
self._stats = stats
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
for cache_record in self._cached_models.values():
total += cache_record.size
return total
def put(
self,
key: str,
model: AnyModel,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(self.logger, model)
self.make_room(size)
running_on_cpu = self.execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
def get(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
stats_name: Optional[str] = None,
) -> ModelLockerBase:
"""
Retrieve model using key and optional submodel_type.
:param key: Opaque model key
:param submodel_type: Type of the submodel to fetch
:param stats_name: A human-readable id for the model for the purposes of
stats reporting.
This may raise an IndexError if the model is not in the cache.
"""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
if self.stats:
self.stats.hits += 1
else:
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return ModelLocker(
cache=self,
cache_entry=cache_entry,
)
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
if submodel_type:
return f"{model_key}:{submodel_type.value}"
else:
return model_key
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload models from the execution_device to make room for size_required.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
)
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(target_device, copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GB):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self.cache_size() / GB)
in_ram_models = 0
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
if hasattr(cache_record.model, "device"):
if cache_record.model.device == self.storage_device:
in_ram_models += 1
else:
in_vram_models += 1
if cache_record.locked:
locked_in_vram_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
bytes_needed = size
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
current_size = self.cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
pos = 0
models_cleared = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
self._delete_cache_entry(cache_entry)
del cache_entry
else:
pos += 1
if models_cleared > 0:
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
# is high even if no garbage gets collected.)
#
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
# - If models had to be cleared, it's a signal that we are close to our memory limit.
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
# collected.
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
if self.stats:
self.stats.cleared = models_cleared
gc.collect()
TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]

View File

@@ -1,64 +0,0 @@
"""
Base class and implementation of a class that moves models in and out of VRAM.
"""
from typing import Dict, Optional
import torch
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
CacheRecord,
ModelCacheBase,
ModelLockerBase,
)
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
"""
Initialize the model locker.
:param cache: The ModelCache object
:param cache_entry: The entry in the model cache
"""
self._cache = cache
self._cache_entry = cache_entry
@property
def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
self._cache_entry.unlock()
raise
except Exception:
self._cache_entry.unlock()
raise
return self.model
def unlock(self) -> None:
"""Call upon exit from context."""
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(0)
self._cache.print_cuda_stats()

View File

@@ -0,0 +1,33 @@
from typing import Any, Callable
import torch
from torch.overrides import TorchFunctionMode
def add_autocast_to_module_forward(m: torch.nn.Module, to_device: torch.device):
"""Monkey-patch m.forward(...) with a new forward(...) method that activates device autocasting for its duration."""
old_forward = m.forward
def new_forward(*args: Any, **kwargs: Any):
with TorchFunctionAutocastDeviceContext(to_device):
return old_forward(*args, **kwargs)
m.forward = new_forward
def _cast_to_device_and_run(
func: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], to_device: torch.device
):
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
return func(*args_on_device, **kwargs_on_device)
class TorchFunctionAutocastDeviceContext(TorchFunctionMode):
def __init__(self, to_device: torch.device):
self._to_device = to_device
def __torch_function__(
self, func: Callable[..., Any], types, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None
):
return _cast_to_device_and_run(func, args, kwargs or {}, self._to_device)

View File

@@ -26,7 +26,7 @@ from invokeai.backend.model_manager import (
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@@ -40,7 +40,7 @@ class LoRALoader(ModelLoader):
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
ram_cache: ModelCache,
):
"""Initialize the loader."""
super().__init__(app_config, logger, ram_cache)

View File

@@ -25,6 +25,7 @@ from invokeai.backend.model_manager.config import (
DiffusersConfigBase,
MainCheckpointConfig,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -132,5 +133,5 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
self._ram_cache.put(get_model_cache_key(config.key, subtype), model=submodel)
return getattr(pipeline, submodel_type.value)

View File

@@ -469,7 +469,7 @@ class ModelProbe(object):
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
@@ -485,6 +485,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
"lineart anime": "lineart_anime_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"lineart": "lineart_image_processor",
"soft": "hed_image_processor",
"softedge": "hed_image_processor",
"hed": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",

View File

@@ -44,7 +44,7 @@ def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str, torch.Tensor]:
if str(path).endswith(".safetensors"):
try:
path_str = path.as_posix() if isinstance(path, Path) else path
@@ -55,7 +55,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
if str(path).endswith(".gguf"):
# The GGUF reader used here uses numpy memmap, so these tensors are not loaded into memory during this function

View File

@@ -0,0 +1,12 @@
import logging
from typing import Any, MutableMapping
# Issue with type hints related to LoggerAdapter: https://github.com/python/typeshed/issues/7855
class PrefixedLoggerAdapter(logging.LoggerAdapter): # type: ignore
def __init__(self, logger: logging.Logger, prefix: str):
super().__init__(logger, {})
self.prefix = prefix
def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, MutableMapping[str, Any]]:
return f"[{self.prefix}] {msg}", kwargs

View File

@@ -96,7 +96,9 @@
"new": "Neu",
"ok": "OK",
"close": "Schließen",
"clipboard": "Zwischenablage"
"clipboard": "Zwischenablage",
"generating": "Generieren",
"loadingModel": "Lade Modell"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -591,7 +593,15 @@
"loraTriggerPhrases": "LoRA-Auslösephrasen",
"installingBundle": "Bündel wird installiert",
"triggerPhrases": "Auslösephrasen",
"mainModelTriggerPhrases": "Hauptmodell-Auslösephrasen"
"mainModelTriggerPhrases": "Hauptmodell-Auslösephrasen",
"noDefaultSettings": "Für dieses Modell sind keine Standardeinstellungen konfiguriert. Besuchen Sie den Modell-Manager, um Standardeinstellungen hinzuzufügen.",
"defaultSettingsOutOfSync": "Einige Einstellungen stimmen nicht mit den Standardeinstellungen des Modells überein:",
"clipLEmbed": "CLIP-L einbetten",
"clipGEmbed": "CLIP-G einbetten",
"hfTokenLabel": "HuggingFace-Token (für einige Modelle erforderlich)",
"hfTokenHelperText": "Für die Nutzung einiger Modelle ist ein HF-Token erforderlich. Klicken Sie hier, um Ihr Token zu erstellen oder zu erhalten.",
"hfForbidden": "Sie haben keinen Zugriff auf dieses HF-Modell",
"hfTokenInvalid": "Ungültiges oder fehlendes HF-Token"
},
"parameters": {
"images": "Bilder",
@@ -841,7 +851,8 @@
"upscaling": "Hochskalierung",
"canvas": "Leinwand",
"prompts_one": "Prompt",
"prompts_other": "Prompts"
"prompts_other": "Prompts",
"batchSize": "Stapelgröße"
},
"metadata": {
"negativePrompt": "Negativ Beschreibung",
@@ -1081,6 +1092,21 @@
},
"patchmatchDownScaleSize": {
"heading": "Herunterskalieren"
},
"paramHeight": {
"heading": "Höhe",
"paragraphs": [
"Höhe des generierten Bildes. Muss ein Vielfaches von 8 sein."
]
},
"paramUpscaleMethod": {
"heading": "Vergrößerungsmethode",
"paragraphs": [
"Methode zum Hochskalieren des Bildes für High Resolution Fix."
]
},
"paramHrf": {
"heading": "High Resolution Fix aktivieren"
}
},
"invocationCache": {

View File

@@ -176,7 +176,8 @@
"reset": "Reset",
"none": "None",
"new": "New",
"generating": "Generating"
"generating": "Generating",
"warnings": "Warnings"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -1038,20 +1039,7 @@
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected",
"layer": {
"controlAdapterNoModelSelected": "no Control Adapter model selected",
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox width is {{width}}",
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox height is {{height}}",
"t2iAdapterIncompatibleScaledBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, scaled bbox width is {{width}}",
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, scaled bbox height is {{height}}",
"ipAdapterNoModelSelected": "no IP adapter selected",
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
"ipAdapterNoImageSelected": "no IP Adapter image selected",
"rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters",
"rgNoRegion": "no region selected"
}
"systemDisconnected": "System disconnected"
},
"maskBlur": "Mask Blur",
"negativePromptPlaceholder": "Negative Prompt",
@@ -1319,8 +1307,9 @@
"controlNetBeginEnd": {
"heading": "Begin / End Step Percentage",
"paragraphs": [
"The part of the of the denoising process that will have the Control Adapter applied.",
"Generally, Control Adapters applied at the start of the process guide composition, and Control Adapters applied at the end guide details."
"This setting determines which portion of the denoising (generation) process incorporates the guidance from this layer.",
"• Start Step (%): Specifies when to begin applying the guidance from this layer during the generation process.",
"• End Step (%): Specifies when to stop applying this layer's guidance and revert general guidance from the model and other settings."
]
},
"controlNetControlMode": {
@@ -1338,13 +1327,15 @@
"paragraphs": ["Method to fit Control Adapter's input image size to the output generation size."]
},
"ipAdapterMethod": {
"heading": "Method",
"paragraphs": ["Method by which to apply the current IP Adapter."]
"heading": "Mode",
"paragraphs": ["The mode defines how the reference image will guide the generation process."]
},
"controlNetWeight": {
"heading": "Weight",
"paragraphs": [
"Weight of the Control Adapter. Higher weight will lead to larger impacts on the final image."
"Adjusts how strongly the layer influences the generation process",
"• Higher Weight (.75-2): Creates a more significant impact on the final result.",
"• Lower Weight (0-.75): Creates a smaller impact on the final result."
]
},
"dynamicPrompts": {
@@ -1710,6 +1701,8 @@
"controlLayer": "Control Layer",
"inpaintMask": "Inpaint Mask",
"regionalGuidance": "Regional Guidance",
"referenceImageRegional": "Reference Image (Regional)",
"referenceImageGlobal": "Reference Image (Global)",
"asRasterLayer": "As $t(controlLayers.rasterLayer)",
"asRasterLayerResize": "As $t(controlLayers.rasterLayer) (Resize)",
"asControlLayer": "As $t(controlLayers.controlLayer)",
@@ -1781,6 +1774,7 @@
"pullBboxIntoLayer": "Pull Bbox into Layer",
"pullBboxIntoReferenceImage": "Pull Bbox into Reference Image",
"showProgressOnCanvas": "Show Progress on Canvas",
"useImage": "Use Image",
"prompt": "Prompt",
"negativePrompt": "Negative Prompt",
"beginEndStepPercentShort": "Begin/End %",
@@ -1793,6 +1787,22 @@
"resetGenerationSettings": "Reset Generation Settings",
"replaceCurrent": "Replace Current",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or draw on the canvas to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer to get started.",
"warnings": {
"problemsFound": "Problems found",
"unsupportedModel": "layer not supported for selected base model",
"controlAdapterNoModelSelected": "no Control Layer model selected",
"controlAdapterIncompatibleBaseModel": "incompatible Control Layer base model",
"controlAdapterNoControl": "no control selected/drawn",
"ipAdapterNoModelSelected": "no Reference Image model selected",
"ipAdapterIncompatibleBaseModel": "incompatible Reference Image base model",
"ipAdapterNoImageSelected": "no Reference Image image selected",
"rgNoPromptsOrIPAdapters": "no text prompts or Reference Images",
"rgNegativePromptNotSupported": "Negative Prompt not supported for selected base model",
"rgReferenceImagesNotSupported": "regional Reference Images not supported for selected base model",
"rgAutoNegativeNotSupported": "Auto-Negative not supported for selected base model",
"rgNoRegion": "no region drawn"
},
"controlMode": {
"controlMode": "Control Mode",
"balanced": "Balanced (recommended)",
@@ -1801,10 +1811,13 @@
"megaControl": "Mega Control"
},
"ipAdapterMethod": {
"ipAdapterMethod": "IP Adapter Method",
"ipAdapterMethod": "Mode",
"full": "Style and Composition",
"fullDesc": "Applies visual style (colors, textures) & composition (layout, structure).",
"style": "Style Only",
"composition": "Composition Only"
"styleDesc": "Applies visual style (colors, textures) without considering its layout.",
"composition": "Composition Only",
"compositionDesc": "Replicates layout & structure while ignoring the reference's style."
},
"fill": {
"fillColor": "Fill Color",

View File

@@ -327,7 +327,6 @@
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, la hauteur de la bounding box est {{height}}",
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, la largeur de la bounding box est {{width}}",
"ipAdapterIncompatibleBaseModel": "modèle de base d'IP adapter incompatible",
"rgNoRegion": "aucune zone sélectionnée",
"controlAdapterNoModelSelected": "aucun modèle de Control Adapter sélectionné"
},
"noPrompts": "Aucun prompts généré",

View File

@@ -96,7 +96,8 @@
"clipboard": "Appunti",
"ok": "Ok",
"generating": "Generazione",
"loadingModel": "Caricamento del modello"
"loadingModel": "Caricamento del modello",
"warnings": "Avvisi"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -671,11 +672,15 @@
"ipAdapterIncompatibleBaseModel": "Il modello base dell'adattatore IP non è compatibile",
"ipAdapterNoImageSelected": "Nessuna immagine dell'adattatore IP selezionata",
"rgNoPromptsOrIPAdapters": "Nessun prompt o adattatore IP",
"rgNoRegion": "Nessuna regione selezionata",
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, larghezza riquadro è {{width}}",
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, altezza riquadro è {{height}}",
"t2iAdapterIncompatibleScaledBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, larghezza del riquadro scalato {{width}}",
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, altezza del riquadro scalato {{height}}"
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, altezza del riquadro scalato {{height}}",
"rgNegativePromptNotSupported": "prompt negativo non supportato per il modello base selezionato",
"rgAutoNegativeNotSupported": "auto-negativo non supportato per il modello base selezionato",
"emptyLayer": "livello vuoto",
"unsupportedModel": "livello non supportato per il modello base selezionato",
"rgReferenceImagesNotSupported": "immagini di riferimento regionali non supportate per il modello base selezionato"
},
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), altezza riquadro è {{height}}",
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), larghezza riquadro è {{width}}",
@@ -687,7 +692,11 @@
"canvasIsTransforming": "La tela sta trasformando",
"canvasIsRasterizing": "La tela sta rasterizzando",
"canvasIsCompositing": "La tela è in fase di composizione",
"canvasIsFiltering": "La tela sta filtrando"
"canvasIsFiltering": "La tela sta filtrando",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi elementi, massimo {{maxItems}}",
"canvasIsSelectingObject": "La tela è occupata (selezione dell'oggetto)",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi pochi elementi, minimo {{minItems}}",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} raccolta vuota"
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@@ -972,7 +981,9 @@
"saveToGallery": "Salva nella Galleria",
"noMatchingWorkflows": "Nessun flusso di lavoro corrispondente",
"noWorkflows": "Nessun flusso di lavoro",
"workflowHelpText": "Hai bisogno di aiuto? Consulta la nostra guida <LinkComponent>Introduzione ai flussi di lavoro</LinkComponent>."
"workflowHelpText": "Hai bisogno di aiuto? Consulta la nostra guida <LinkComponent>Introduzione ai flussi di lavoro</LinkComponent>.",
"specialDesc": "Questa invocazione comporta una gestione speciale nell'applicazione. Ad esempio, i nodi Lotto vengono utilizzati per mettere in coda più grafici da un singolo flusso di lavoro.",
"internalDesc": "Questa invocazione è utilizzata internamente da Invoke. Potrebbe subire modifiche significative durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento."
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1093,7 +1104,8 @@
"workflows": "Flussi di lavoro",
"generation": "Generazione",
"other": "Altro",
"gallery": "Galleria"
"gallery": "Galleria",
"batchSize": "Dimensione del lotto"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -1196,7 +1208,8 @@
"heading": "Percentuale passi Inizio / Fine",
"paragraphs": [
"La parte del processo di rimozione del rumore in cui verrà applicato l'adattatore di controllo.",
"In genere, gli adattatori di controllo applicati all'inizio del processo guidano la composizione, mentre quelli applicati alla fine guidano i dettagli."
"In genere, gli adattatori di controllo applicati all'inizio del processo guidano la composizione, mentre quelli applicati alla fine guidano i dettagli.",
"• Passo finale (%): specifica quando interrompere l'applicazione della guida di questo livello e ripristinare la guida generale dal modello e altre impostazioni."
]
},
"noiseUseCPU": {
@@ -1300,7 +1313,9 @@
"controlNetWeight": {
"heading": "Peso",
"paragraphs": [
"Peso dell'adattatore di controllo. Un peso maggiore porterà a impatti maggiori sull'immagine finale."
"Regola la forza con cui il livello influenza il processo di generazione",
"• Peso maggiore (0.75-2): crea un impatto più significativo sul risultato finale.",
"• Peso inferiore (0-0.75): crea un impatto minore sul risultato finale."
]
},
"paramCFGScale": {
@@ -1801,7 +1816,10 @@
"full": "Stile e Composizione",
"style": "Solo Stile",
"composition": "Solo Composizione",
"ipAdapterMethod": "Metodo Adattatore IP"
"ipAdapterMethod": "Metodo Adattatore IP",
"fullDesc": "Applica lo stile visivo (colori, texture) e la composizione (disposizione, struttura).",
"styleDesc": "Applica lo stile visivo (colori, texture) senza considerare la disposizione.",
"compositionDesc": "Replica disposizione e struttura ignorando lo stile di riferimento."
},
"showingType": "Mostra {{type}}",
"dynamicGrid": "Griglia dinamica",
@@ -2044,7 +2062,16 @@
"replaceCurrent": "Sostituisci corrente",
"mergeDown": "Unire in basso",
"mergingLayers": "Unione dei livelli",
"controlLayerEmptyState": "<UploadButton>Carica un'immagine</UploadButton>, trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello oppure disegna sulla tela per iniziare."
"controlLayerEmptyState": "<UploadButton>Carica un'immagine</UploadButton>, trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello oppure disegna sulla tela per iniziare.",
"useImage": "Usa immagine",
"resetGenerationSettings": "Ripristina impostazioni di generazione",
"referenceImageEmptyState": "Per iniziare, <UploadButton>carica un'immagine</UploadButton> oppure trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello.",
"asRasterLayer": "Come $t(controlLayers.rasterLayer)",
"asRasterLayerResize": "Come $t(controlLayers.rasterLayer) (Ridimensiona)",
"asControlLayer": "Come $t(controlLayers.controlLayer)",
"asControlLayerResize": "Come $t(controlLayers.controlLayer) (Ridimensiona)",
"newSession": "Nuova sessione",
"resetCanvasLayers": "Ripristina livelli Tela"
},
"ui": {
"tabs": {
@@ -2144,7 +2171,7 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"<StrongComponent>SD 3.5</StrongComponent>: supporto per SD 3.5 Medium e Large.",
"<StrongComponent>Flussi di lavoro</StrongComponent>: esegui un flusso di lavoro per una raccolta di immagini utilizzando il nuovo nodo <StrongComponent>Lotto di immagini</StrongComponent>.",
"<StrongComponent>Tela</StrongComponent>: elaborazione semplificata del livello di controllo e impostazioni di controllo predefinite migliorate."
]
},
@@ -2172,5 +2199,67 @@
"logNamespaces": "Elementi del registro"
},
"enableLogging": "Abilita la registrazione"
},
"supportVideos": {
"gettingStarted": "Iniziare",
"supportVideos": "Video di supporto",
"videos": {
"usingControlLayersAndReferenceGuides": {
"title": "Utilizzo di livelli di controllo e guide di riferimento",
"description": "Scopri come guidare la creazione delle tue immagini con livelli di controllo e immagini di riferimento."
},
"creatingYourFirstImage": {
"description": "Introduzione alla creazione di un'immagine da zero utilizzando gli strumenti di Invoke.",
"title": "Creazione della tua prima immagine"
},
"understandingImageToImageAndDenoising": {
"description": "Panoramica delle trasformazioni immagine-a-immagine e della riduzione del rumore in Invoke.",
"title": "Comprendere immagine-a-immagine e riduzione del rumore"
},
"howDoIDoImageToImageTransformation": {
"description": "Tutorial su come eseguire trasformazioni da immagine a immagine in Invoke.",
"title": "Come si esegue la trasformazione da immagine-a-immagine?"
},
"howDoIUseInpaintMasks": {
"title": "Come si usano le maschere Inpaint?",
"description": "Come applicare maschere inpaint per la correzione e la variazione delle immagini."
},
"howDoIOutpaint": {
"description": "Guida all'outpainting oltre i confini dell'immagine originale.",
"title": "Come posso eseguire l'outpainting?"
},
"exploringAIModelsAndConceptAdapters": {
"description": "Approfondisci i modelli di intelligenza artificiale e scopri come utilizzare gli adattatori concettuali per il controllo creativo.",
"title": "Esplorazione dei modelli di IA e degli adattatori concettuali"
},
"upscaling": {
"title": "Ampliamento",
"description": "Come ampliare le immagini con gli strumenti di Invoke per migliorarne la risoluzione."
},
"creatingAndComposingOnInvokesControlCanvas": {
"description": "Impara a comporre immagini utilizzando la tela di controllo di Invoke.",
"title": "Creare e comporre sulla tela di controllo di Invoke"
},
"howDoIGenerateAndSaveToTheGallery": {
"description": "Passaggi per generare e salvare le immagini nella galleria.",
"title": "Come posso generare e salvare nella Galleria?"
},
"howDoIEditOnTheCanvas": {
"title": "Come posso apportare modifiche sulla tela?",
"description": "Guida alla modifica delle immagini direttamente sulla tela."
},
"howDoIUseControlNetsAndControlLayers": {
"title": "Come posso utilizzare le Reti di Controllo e i Livelli di Controllo?",
"description": "Impara ad applicare livelli di controllo e reti di controllo alle tue immagini."
},
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
"title": "Come si utilizzano gli adattatori IP globali e le immagini di riferimento?",
"description": "Introduzione all'aggiunta di immagini di riferimento e adattatori IP globali."
}
},
"controlCanvas": "Tela di Controllo",
"watch": "Guarda",
"studioSessionsDesc1": "Dai un'occhiata a <StudioSessionsPlaylistLink /> per approfondimenti su Invoke.",
"studioSessionsDesc2": "Unisciti al nostro <DiscordLink /> per partecipare alle sessioni live e fare domande. Le sessioni vengono caricate sulla playlist la settimana successiva."
}
}

View File

@@ -236,7 +236,6 @@
"controlAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor controle-adapter",
"ipAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor IP-adapter",
"ipAdapterNoImageSelected": "geen afbeelding voor IP-adapter geselecteerd",
"rgNoRegion": "geen gebied geselecteerd",
"rgNoPromptsOrIPAdapters": "geen tekstprompts of IP-adapters",
"ipAdapterNoModelSelected": "geen IP-adapter geselecteerd"
}

View File

@@ -10,7 +10,24 @@
"load": "Załaduj",
"statusDisconnected": "Odłączono od serwera",
"githubLabel": "GitHub",
"discordLabel": "Discord"
"discordLabel": "Discord",
"clipboard": "Schowek",
"aboutDesc": "Wykorzystujesz Invoke do pracy? Sprawdź:",
"ai": "SI",
"areYouSure": "Czy jesteś pewien?",
"copyError": "$t(gallery.copy) Błąd",
"apply": "Zastosuj",
"copy": "Kopiuj",
"or": "albo",
"add": "Dodaj",
"off": "Wyłączony",
"accept": "Zaakceptuj",
"cancel": "Anuluj",
"advanced": "Zawansowane",
"back": "Do tyłu",
"auto": "Automatyczny",
"beta": "Beta",
"close": "Wyjdź"
},
"gallery": {
"galleryImageSize": "Rozmiar obrazów",
@@ -65,6 +82,42 @@
"uploadImage": "Wgrywanie obrazu",
"previousImage": "Poprzedni obraz",
"nextImage": "Następny obraz",
"menu": "Menu"
"menu": "Menu",
"mode": "Tryb"
},
"boards": {
"cancel": "Anuluj",
"noBoards": "Brak tablic typu {{boardType}}",
"imagesWithCount_one": "{{count}} zdjęcie",
"imagesWithCount_few": "{{count}} zdjęcia",
"imagesWithCount_many": "{{count}} zdjęcia",
"private": "Prywatne tablice",
"updateBoardError": "Błąd aktualizacji tablicy",
"uncategorized": "Nieskategoryzowane",
"selectBoard": "Wybierz tablicę",
"downloadBoard": "Pobierz tablice",
"loading": "Ładowanie...",
"move": "Przenieś",
"noMatching": "Brak pasujących tablic"
},
"accordions": {
"compositing": {
"title": "Kompozycja",
"infillTab": "Inskrypcja",
"coherenceTab": "Przebieg Koherencji"
},
"generation": {
"title": "Generowanie"
},
"image": {
"title": "Zdjęcie"
},
"advanced": {
"options": "$t(accordions.advanced.title) Opcje",
"title": "Zaawansowane"
},
"control": {
"title": "Kontrola"
}
}
}

View File

@@ -652,7 +652,6 @@
"ipAdapterNoModelSelected": "IP адаптер не выбран",
"controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
"controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
"rgNoRegion": "регион не выбран",
"rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
"ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
"ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано",

View File

@@ -217,7 +217,10 @@
"direction": "Phương Hướng",
"unknownError": "Lỗi Không Rõ",
"selected": "Đã chọn",
"tab": "Tab"
"tab": "Tab",
"loadingModel": "Đang Tải Model",
"generating": "Đang Tạo Sinh",
"warnings": "Cảnh Báo"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -290,7 +293,8 @@
"cancelSucceeded": "Mục Đã Huỷ Bỏ",
"completedIn": "Hoàn tất trong",
"graphQueued": "Đồ Thị Đã Vào Hàng",
"batchQueuedDesc_other": "Thêm {{count}} phiên vào {{direction}} của hàng"
"batchQueuedDesc_other": "Thêm {{count}} phiên vào {{direction}} của hàng",
"batchSize": "Kích Thước Vùng Hàng Loạt"
},
"hotkeys": {
"canvas": {
@@ -733,7 +737,9 @@
"textualInversions": "Bộ Đảo Ngược Văn Bản",
"loraTriggerPhrases": "Từ Ngữ Kích Hoạt Cho LoRA",
"width": "Chiều Rộng",
"starterModelsInModelManager": "Model khởi đầu có thể tìm thấy ở Trình Quản Lý Model"
"starterModelsInModelManager": "Model khởi đầu có thể tìm thấy ở Trình Quản Lý Model",
"clipLEmbed": "CLIP-L Embed",
"clipGEmbed": "CLIP-G Embed"
},
"metadata": {
"guidance": "Hướng Dẫn",
@@ -905,7 +911,7 @@
"unknownNode": "Node Không Rõ",
"unknownNodeType": "Loại Node Không Rõ",
"unknownTemplate": "Mẫu Trình Bày Không Rõ",
"cannotConnectOutputToOutput": "Không thế kết nối đầu ra với đầu vào",
"cannotConnectOutputToOutput": "Không thế kết nối đầu ra với đầu ra",
"cannotConnectToSelf": "Không thể kết nối với chính nó",
"workflow": "Workflow",
"addNodeToolTip": "Thêm Node (Shift+A, Space)",
@@ -952,7 +958,9 @@
"executionStateInProgress": "Đang Xử Lý",
"showLegendNodes": "Hiển Thị Vùng Nhập",
"outputFieldTypeParseError": "Không thể phân tích loại dữ liệu đầu ra của {{node}}.{{field}} ({{message}})",
"modelAccessError": "Không thể tìm thấy model {{key}}, chuyển về mặc định"
"modelAccessError": "Không thể tìm thấy model {{key}}, chuyển về mặc định",
"internalDesc": "Trình kích hoạt này được dùng bên trong bởi Invoke. Nó có thể phá hỏng thay đổi trong khi cập nhật ứng dụng và có thể bị xoá bất cứ lúc nào.",
"specialDesc": "Trình kích hoạt này có một số xử lý đặc biệt trong ứng dụng. Ví dụ, Node Hàng Loạt được dùng để xếp vào nhiều đồ thị từ một workflow."
},
"popovers": {
"paramCFGRescaleMultiplier": {
@@ -1105,7 +1113,9 @@
},
"controlNetWeight": {
"paragraphs": [
"Trọng lượng của Control Adapter. Trọng lượng càng cao sẽ dẫn đến tác động càng lớn lên ảnh cuối cùng."
"Điều chỉnh mức độ layer ảnh hưởng đến quá trình xử lý tạo sinh.",
"• Trọng Lượng Lớn Hơn (.75-2): Gây ra ảnh hưởng lớn hơn lên kết quả cuối cùng.",
"• Trọng Lượng Nhỏ Hơn (0-.75): Gây ra ảnh hưởng nhỏ hơn lên kết quả cuối cùng."
],
"heading": "Trọng Lượng"
},
@@ -1149,7 +1159,7 @@
},
"ipAdapterMethod": {
"paragraphs": [
"Cách thức dùng để áp dụng IP Adapter hiện tại."
"Phương thức định nghĩa cách ảnh mẫu sẽ chỉ dẫn quá trình xử lý tạo sinh."
],
"heading": "Cách Thức"
},
@@ -1196,8 +1206,9 @@
},
"controlNetBeginEnd": {
"paragraphs": [
"Một phần trong quá trình xử lý khử nhiễu mà sẽ được Control Adapter áp dụng.",
"Nói chung, Control Adapter áp dụng vào lúc bắt đầu của quá trình hướng dẫn thành phần, và cũng áp dụng vào lúc kết thúc hướng dẫn chi tiết."
"Cài đặt này xác định phần xử lý khử nhiễu (trong khi tạo sinh) kết hợp với chỉ dẫn từ layer này.",
"• Bước Bắt Đầu (%): Chỉ định lúc bắt đầu áp dụng chỉ dẫn từ layer này trong quá trình tạo sinh.",
"• Bước Kết Thúc (%): Chỉ định lúc dừng áp dụng chỉ dẫn của layer này và trở về chỉ dẫn chung từ model và các thiết lập khác."
],
"heading": "Phần Trăm Tham Số Bước Khi Bắt Đầu/Kết Thúc"
},
@@ -1401,7 +1412,6 @@
"invoke": {
"layer": {
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, tỉ lệ chiều dài hộp giới hạn là {{height}}",
"rgNoRegion": "không có vùng được chọn",
"ipAdapterNoModelSelected": "không có IP Adapter được lựa chọn",
"ipAdapterNoImageSelected": "không có ảnh IP Adapter được lựa chọn",
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, chiều dài hộp giới hạn là {{height}}",
@@ -1410,15 +1420,20 @@
"rgNoPromptsOrIPAdapters": "không có lệnh chữ hoặc IP Adapter",
"controlAdapterIncompatibleBaseModel": "model cơ sở của Control Adapter không tương thích",
"ipAdapterIncompatibleBaseModel": "dạng model cơ sở của IP Adapter không tương thích",
"controlAdapterNoModelSelected": "không có model Control Adapter được chọn"
"controlAdapterNoModelSelected": "không có model Control Adapter được chọn",
"emptyLayer": "layer trống",
"rgAutoNegativeNotSupported": "trình tự động đảo chiều không được hỗ trợ cho model cơ sở đang dùng",
"rgNegativePromptNotSupported": "lệnh tiêu cực không được hỗ trợ cho model cơ sở đang dùng",
"unsupportedModel": "layer không được hỗ trợ cho model cơ sở đang dùng",
"rgReferenceImagesNotSupported": "ảnh mẫu khu vực không được hỗ trợ cho model cơ sở đang dùng"
},
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), chiều rộng hộp giới hạn là {{width}}",
"noModelSelected": "Không có model được lựa chọn",
"fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), tỉ lệ chiều dài hộp giới hạn là {{height}}",
"canvasIsFiltering": "Canvas đang được lọc",
"canvasIsRasterizing": "Canvas đang được raster hoá",
"canvasIsTransforming": "Canvas đang được biến đổi",
"canvasIsCompositing": "Canvas đang được kết hợp",
"canvasIsFiltering": "Canvas đang bận (đang lọc)",
"canvasIsRasterizing": "Canvas đang bận (đang raster hoá)",
"canvasIsTransforming": "Canvas đang bận (đang biến đổi)",
"canvasIsCompositing": "Canvas đang bận (đang kết hợp)",
"noPrompts": "Không có lệnh được tạo",
"noNodesInGraph": "Không có node trong đồ thị",
"addingImagesTo": "Thêm ảnh vào",
@@ -1430,8 +1445,12 @@
"missingNodeTemplate": "Thiếu mẫu trình bày node",
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), chiều dài hộp giới hạn là {{height}}",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), tỉ lệ chiều rộng hộp giới hạn là {{width}}",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} thiếu đầu ra",
"missingFieldTemplate": "Thiếu vùng mẫu trình bày"
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: thiếu đầu vào",
"missingFieldTemplate": "Thiếu vùng mẫu trình bày",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} tài nguyên trống",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: quá ít mục, tối thiểu {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: quá nhiều mục, tối đa {{maxItems}}",
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)"
},
"cfgScale": "Thước Đo CFG",
"useSeed": "Dùng Tham Số Hạt Giống",
@@ -1542,7 +1561,8 @@
"resetWebUIDesc2": "Nếu ảnh không được xuất hiện trong thư viện hoặc điều gì đó không ổn đang diễn ra, hãy thử khởi động lại trước khi báo lỗi trên Github.",
"displayInProgress": "Hiển Thị Hình Ảnh Đang Xử Lý",
"intermediatesClearedFailed": "Có Vấn Đề Khi Dọn Sạch Sản Phẩm Trung Gian",
"enableInvisibleWatermark": "Bật Chế Độ Ẩn Watermark"
"enableInvisibleWatermark": "Bật Chế Độ Ẩn Watermark",
"showDetailedInvocationProgress": "Hiện Dữ Liệu Xử Lý"
},
"sdxl": {
"loading": "Đang Tải...",
@@ -1594,7 +1614,7 @@
"pullBboxIntoLayerError": "Có Vấn Đề Khi Chuyển Hộp Giới Hạn Thành Layer",
"pullBboxIntoReferenceImageOk": "Chuyển Hộp Giới Hạn Thành Ảnh Mẫu",
"clearCaches": "Xoá Bộ Nhớ Đệm",
"outputOnlyMaskedRegions": "Chỉ Xuất Đầu Ra Ở Vùng Phủ",
"outputOnlyMaskedRegions": "Chỉ Xuất Đầu Ra Ở Vùng Tạo Sinh",
"addLayer": "Thêm Layer",
"regional": "Khu Vực",
"regionIsEmpty": "Vùng được chọn trống",
@@ -1608,10 +1628,13 @@
"moveForward": "Chuyển Lên Đầu",
"fitBboxToLayers": "Xếp Vừa Hộp Giới Hạn Vào Layer",
"ipAdapterMethod": {
"full": "Đầy Đủ",
"full": "Phong Cách Và Thành Phần",
"style": "Chỉ Lấy Phong Cách",
"composition": "Chỉ Lấy Thành Phần",
"ipAdapterMethod": "Cách Thức IP Adapter"
"ipAdapterMethod": "Cách Thức",
"compositionDesc": "Áp dụng cách trình bày và bỏ qua phong cách mẫu.",
"fullDesc": "Áp dụng phong cách trực quan (màu, cấu tạo) & thành phần (cách trình bày).",
"styleDesc": "Áp dụng phong cách trực quan (màu, cấu tạo) và bỏ qua cách trình bày."
},
"deletePrompt": "Xoá Lệnh",
"rasterLayer": "Layer Dạng Raster",
@@ -1899,7 +1922,16 @@
"colorPicker": "Chọn Màu"
},
"mergingLayers": "Đang gộp layer",
"controlLayerEmptyState": "<UploadButton>Tải lên ảnh</UploadButton>, kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này, hoặc vẽ trên canvas để bắt đầu."
"controlLayerEmptyState": "<UploadButton>Tải lên ảnh</UploadButton>, kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này, hoặc vẽ trên canvas để bắt đầu.",
"referenceImageEmptyState": "<UploadButton>Tải lên ảnh</UploadButton> hoặc kéo thả ảnh từ <GalleryButton>thư viện</GalleryButton> vào layer này để bắt đầu.",
"useImage": "Dùng Hình Ảnh",
"resetCanvasLayers": "Khởi Động Lại Layer Canvas",
"asRasterLayer": "Như $t(controlLayers.rasterLayer)",
"asRasterLayerResize": "Như $t(controlLayers.rasterLayer) (Thay Đổi Kích Thước)",
"asControlLayer": "Như $t(controlLayers.controlLayer)",
"asControlLayerResize": "Như $t(controlLayers.controlLayer) (Thay Đổi Kích Thước)",
"newSession": "Phiên Làm Việc Mới",
"resetGenerationSettings": "Khởi Động Lại Cài Đặt Tạo Sinh"
},
"stylePresets": {
"negativePrompt": "Lệnh Tiêu Cực",
@@ -2124,8 +2156,8 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"<StrongComponent>SD 3.5</StrongComponent>: Hỗ trợ cho Từ ngữ Sang Hình Ảnh trong Workflow với phiên bản SD 3.5 Medium hoặc Large.",
"<StrongComponent>Canvas</StrongComponent>: Hợp lý hoá cách xử lý Layer Điều Khiển Được và cải thiện thiết lập điều khiển mặc định."
"<StrongComponent>Workflows</StrongComponent>: Chạy một workflow cho nhiều ảnh bằng node <StrongComponent>Ảnh Hàng Loạt</StrongComponent> mới.",
"<StrongComponent>FLUX</StrongComponent>: Hỗ trợ cho XLabs IP Adapter v2."
]
},
"upsell": {
@@ -2133,5 +2165,67 @@
"inviteTeammates": "Thêm Đồng Đội",
"shareAccess": "Chia Sẻ Quyền Truy Cập",
"professionalUpsell": "Không có sẵn Phiên Bản Chuyên Nghiệp cho Invoke. Bấm vào đây hoặc đến invoke.com/pricing để thêm chi tiết."
},
"supportVideos": {
"supportVideos": "Video Hỗ Trợ",
"gettingStarted": "Bắt Đầu Làm Quen",
"studioSessionsDesc1": "Xem thử <StudioSessionsPlaylistLink /> để hiểu rõ Invoke hơn.",
"studioSessionsDesc2": "Đến <DiscordLink /> để tham gia vào phiên trực tiếp và hỏi câu hỏi. Các phiên được tải lên danh sách phát vào các tuần.",
"videos": {
"howDoIDoImageToImageTransformation": {
"title": "Làm Sao Để Tôi Dùng Trình Biến Đổi Hình Ảnh Sang Hình Ảnh?",
"description": "Hướng dẫn cách thực hiện biến đổi ảnh sang ảnh trong Invoke."
},
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
"description": "Giới thiệu về ảnh mẫu và IP adapter toàn vùng.",
"title": "Làm Sao Để Tôi Dùng IP Adapter Toàn Vùng Và Ảnh Mẫu?"
},
"creatingAndComposingOnInvokesControlCanvas": {
"description": "Học cách sáng tạo ảnh bằng trình điều khiển canvas của Invoke.",
"title": "Sáng Tạo Trong Trình Kiểm Soát Canvas Của Invoke"
},
"upscaling": {
"description": "Cách upscale ảnh bằng bộ công cụ của Invoke để nâng cấp độ phân giải.",
"title": "Upscale (Nâng Cấp Chất Lượng Hình Ảnh)"
},
"howDoIGenerateAndSaveToTheGallery": {
"title": "Làm Sao Để Tôi Tạo Sinh Và Lưu Vào Thư Viện?",
"description": "Các bước để tạo sinh và lưu ảnh vào thư viện."
},
"howDoIEditOnTheCanvas": {
"description": "Hướng dẫn chỉnh sửa ảnh trực tiếp trên canvas.",
"title": "Làm Sao Để Tôi Chỉnh Sửa Trên Canvas?"
},
"howDoIUseControlNetsAndControlLayers": {
"title": "Làm Sao Để Tôi Dùng ControlNet và Layer Điều Khiển Được?",
"description": "Học cách áp dụng layer điều khiển được và controlnet vào ảnh của bạn."
},
"howDoIUseInpaintMasks": {
"title": "Làm Sao Để Tôi Dùng Lớp Phủ Inpaint?",
"description": "Cách áp dụng lớp phủ inpaint vào chỉnh sửa và thay đổi ảnh."
},
"howDoIOutpaint": {
"title": "Làm Sao Để Tôi Outpaint?",
"description": "Hướng dẫn outpaint bên ngoài viền ảnh gốc."
},
"creatingYourFirstImage": {
"description": "Giới thiệu về cách tạo ảnh từ ban đầu bằng công cụ Invoke.",
"title": "Tạo Hình Ảnh Đầu Tiên Của Bạn"
},
"usingControlLayersAndReferenceGuides": {
"description": "Học cách chỉ dẫn ảnh được tạo ra bằng layer điều khiển được và ảnh mẫu.",
"title": "Dùng Layer Điều Khiển Được và Chỉ Dẫn Mẫu"
},
"understandingImageToImageAndDenoising": {
"title": "Hiểu Rõ Trình Hình Ảnh Sang Hình Ảnh Và Trình Khử Nhiễu",
"description": "Tổng quan về trình biến đổi ảnh sang ảnh và trình khử nhiễu trong Invoke."
},
"exploringAIModelsAndConceptAdapters": {
"title": "Khám Phá Model AI Và Khái Niệm Về Adapter",
"description": "Đào sâu vào model AI và cách dùng những adapter để điều khiển một cách sáng tạo."
}
},
"controlCanvas": "Điều Khiển Canvas",
"watch": "Xem"
}
}

View File

@@ -668,7 +668,6 @@
"controlAdapterIncompatibleBaseModel": "Control Adapter的基础模型不兼容",
"ipAdapterIncompatibleBaseModel": "IP Adapter的基础模型不兼容",
"ipAdapterNoImageSelected": "未选择IP Adapter图像",
"rgNoRegion": "未选择区域",
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}},边界框宽度为 {{width}}",
"t2iAdapterIncompatibleScaledBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}},缩放后的边界框高度为 {{height}}",
"t2iAdapterIncompatibleBboxHeight": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}},边界框高度为 {{height}}",

View File

@@ -1,3 +1,4 @@
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { withResultAsync } from 'common/util/result';
@@ -9,6 +10,7 @@ import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { sentImageToCanvas } from 'features/gallery/store/actions';
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { $isWorkflowListMenuIsOpen } from 'features/nodes/store/workflowListMenu';
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
@@ -51,6 +53,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const { t } = useTranslation();
// Use a ref to ensure that we only perform the action once
const didInit = useRef(false);
const didParseOpenAPISchema = useStore($hasTemplates);
const store = useAppStore();
const { getAndLoadWorkflow } = useGetAndLoadLibraryWorkflow();
@@ -174,7 +177,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
);
useEffect(() => {
if (didInit.current || !action) {
if (didInit.current || !action || !didParseOpenAPISchema) {
return;
}
@@ -187,22 +190,29 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
case 'selectStylePreset':
handleSelectStylePreset(action.data.stylePresetId);
break;
case 'sendToCanvas':
handleSendToCanvas(action.data.imageName);
break;
case 'useAllParameters':
handleUseAllMetadata(action.data.imageName);
break;
case 'goToDestination':
handleGoToDestination(action.data.destination);
break;
default:
break;
}
}, [
handleSendToCanvas,
handleUseAllMetadata,
action,
handleLoadWorkflow,
handleSelectStylePreset,
handleGoToDestination,
handleLoadWorkflow,
didParseOpenAPISchema,
]);
};

View File

@@ -46,7 +46,7 @@ const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = {
/**
* The currently-focused region or `null` if no region is focused.
*/
const $focusedRegion = atom<FocusRegionName | null>(null);
export const $focusedRegion = atom<FocusRegionName | null>(null);
/**
* A map of focus regions to atoms that indicate if that region is focused.

View File

@@ -1,6 +1,6 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
import { useInvoke } from 'features/queue/hooks/useInvoke';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';

View File

@@ -63,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX || isSD3}
isDisabled={isSD3}
>
{t('controlLayers.regionalGuidance')}
</Button>

View File

@@ -49,7 +49,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>

View File

@@ -1,8 +1,10 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type { IPMethodV2 } from 'features/controlLayers/store/types';
import { isIPMethodV2 } from 'features/controlLayers/store/types';
import { selectSystemShouldEnableModelDescriptions } from 'features/system/store/systemSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { assert } from 'tsafe';
@@ -14,13 +16,27 @@ type Props = {
export const IPAdapterMethod = memo(({ method, onChange }: Props) => {
const { t } = useTranslation();
const shouldShowModelDescriptions = useAppSelector(selectSystemShouldEnableModelDescriptions);
const options: { label: string; value: IPMethodV2 }[] = useMemo(
() => [
{ label: t('controlLayers.ipAdapterMethod.full'), value: 'full' },
{ label: t('controlLayers.ipAdapterMethod.style'), value: 'style' },
{ label: t('controlLayers.ipAdapterMethod.composition'), value: 'composition' },
{
label: t('controlLayers.ipAdapterMethod.full'),
value: 'full',
description: shouldShowModelDescriptions ? t('controlLayers.ipAdapterMethod.fullDesc') : undefined,
},
{
label: t('controlLayers.ipAdapterMethod.style'),
value: 'style',
description: shouldShowModelDescriptions ? t('controlLayers.ipAdapterMethod.styleDesc') : undefined,
},
{
label: t('controlLayers.ipAdapterMethod.composition'),
value: 'composition',
description: shouldShowModelDescriptions ? t('controlLayers.ipAdapterMethod.compositionDesc') : undefined,
},
],
[t]
[t, shouldShowModelDescriptions]
);
const _onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@@ -5,6 +5,7 @@ import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginE
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
@@ -17,7 +18,7 @@ import {
referenceImageIPAdapterWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import { selectCanvasSlice, selectEntity, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
@@ -35,7 +36,7 @@ const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'referenc
(canvas) => selectEntityOrThrow(canvas, entityIdentifier, 'IPAdapterSettings').ipAdapter
);
export const IPAdapterSettings = memo(() => {
const IPAdapterSettingsContent = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('reference_image');
@@ -134,4 +135,25 @@ export const IPAdapterSettings = memo(() => {
);
});
IPAdapterSettingsContent.displayName = 'IPAdapterSettingsContent';
const buildSelectIPAdapterHasImage = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
createSelector(selectCanvasSlice, (canvas) => {
const referenceImage = selectEntity(canvas, entityIdentifier);
return !!referenceImage && referenceImage.ipAdapter.image !== null;
});
export const IPAdapterSettings = memo(() => {
const entityIdentifier = useEntityIdentifierContext('reference_image');
const selectIPAdapterHasImage = useMemo(() => buildSelectIPAdapterHasImage(entityIdentifier), [entityIdentifier]);
const hasImage = useAppSelector(selectIPAdapterHasImage);
if (!hasImage) {
return <IPAdapterSettingsEmptyState />;
}
return <IPAdapterSettingsContent />;
});
IPAdapterSettings.displayName = 'IPAdapterSettings';

View File

@@ -0,0 +1,64 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { setGlobalReferenceImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
export const IPAdapterSettingsEmptyState = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext('reference_image');
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
setGlobalReferenceImage({ imageDTO, entityIdentifier, dispatch });
},
[dispatch, entityIdentifier]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ entityIdentifier }),
[entityIdentifier]
);
const components = useMemo(
() => ({
UploadButton: (
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}),
[isBusy, onClickGalleryButton, uploadApi]
);
return (
<Flex flexDir="column" gap={3} position="relative" w="full" p={4}>
<Text textAlign="center" color="base.300">
<Trans i18nKey="controlLayers.referenceImageEmptyState" components={components} />
</Text>
<input {...uploadApi.getUploadInputProps()} />
<DndDropTarget
dndTarget={setGlobalReferenceImageDndTarget}
dndTargetData={dndTargetData}
label={t('controlLayers.useImage')}
isDisabled={isBusy}
/>
</Flex>
);
});
IPAdapterSettingsEmptyState.displayName = 'IPAdapterSettingsEmptyState';

View File

@@ -1,27 +1,28 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import type { IconButtonProps } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';
import { PiXBold } from 'react-icons/pi';
type Props = {
type Props = Omit<IconButtonProps, 'aria-label'> & {
onDelete: () => void;
};
export const RegionalGuidanceDeletePromptButton = memo(({ onDelete }: Props) => {
export const RegionalGuidanceDeletePromptButton = memo(({ onDelete, ...rest }: Props) => {
const { t } = useTranslation();
return (
<Tooltip label={t('controlLayers.deletePrompt')}>
<IconButton
variant="link"
aria-label={t('controlLayers.deletePrompt')}
icon={<PiTrashSimpleFill />}
onClick={onDelete}
flexGrow={0}
size="sm"
p={0}
colorScheme="error"
/>
</Tooltip>
<IconButton
tooltip={t('common.delete')}
variant="link"
aria-label={t('common.delete')}
icon={<PiXBold />}
onClick={onDelete}
flexGrow={0}
size="sm"
p={0}
colorScheme="error"
{...rest}
/>
);
});

View File

@@ -6,6 +6,7 @@ import { Weight } from 'features/controlLayers/components/common/Weight';
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
@@ -19,12 +20,12 @@ import {
rgIPAdapterWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice, selectRegionalGuidanceReferenceImage } from 'features/controlLayers/store/selectors';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { CanvasEntityIdentifier, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiTrashSimpleFill } from 'react-icons/pi';
import { PiBoundingBoxBold, PiXBold } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@@ -32,7 +33,7 @@ type Props = {
referenceImageId: string;
};
export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Props) => {
const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Props) => {
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
const { t } = useTranslation();
const dispatch = useAppDispatch();
@@ -115,7 +116,7 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Pro
size="sm"
variant="link"
alignSelf="stretch"
icon={<PiTrashSimpleFill />}
icon={<PiXBold />}
tooltip={t('controlLayers.deleteReferenceImage')}
aria-label={t('controlLayers.deleteReferenceImage')}
onClick={onDeleteIPAdapter}
@@ -161,4 +162,31 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Pro
);
});
RegionalGuidanceIPAdapterSettingsContent.displayName = 'RegionalGuidanceIPAdapterSettingsContent';
const buildSelectIPAdapterHasImage = (
entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>,
referenceImageId: string
) =>
createSelector(selectCanvasSlice, (canvas) => {
const referenceImage = selectRegionalGuidanceReferenceImage(canvas, entityIdentifier, referenceImageId);
return !!referenceImage && referenceImage.ipAdapter.image !== null;
});
export const RegionalGuidanceIPAdapterSettings = memo(({ referenceImageId }: Props) => {
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
const selectIPAdapterHasImage = useMemo(
() => buildSelectIPAdapterHasImage(entityIdentifier, referenceImageId),
[entityIdentifier, referenceImageId]
);
const hasImage = useAppSelector(selectIPAdapterHasImage);
if (!hasImage) {
return <RegionalGuidanceIPAdapterSettingsEmptyState referenceImageId={referenceImageId} />;
}
return <RegionalGuidanceIPAdapterSettingsContent referenceImageId={referenceImageId} />;
});
RegionalGuidanceIPAdapterSettings.displayName = 'RegionalGuidanceIPAdapterSettings';

View File

@@ -0,0 +1,99 @@
import { Button, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { rgIPAdapterDeleted } from 'features/controlLayers/store/canvasSlice';
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { setRegionalGuidanceReferenceImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
type Props = {
referenceImageId: string;
};
export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImageId }: Props) => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const onUpload = useCallback(
(imageDTO: ImageDTO) => {
setRegionalGuidanceReferenceImage({ imageDTO, entityIdentifier, referenceImageId, dispatch });
},
[dispatch, entityIdentifier, referenceImageId]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const onDeleteIPAdapter = useCallback(() => {
dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId }));
}, [dispatch, entityIdentifier, referenceImageId]);
const dndTargetData = useMemo<SetRegionalGuidanceReferenceImageDndTargetData>(
() =>
setRegionalGuidanceReferenceImageDndTarget.getData({
entityIdentifier,
referenceImageId,
}),
[entityIdentifier, referenceImageId]
);
return (
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex alignItems="center" gap={2}>
<Text fontWeight="semibold" color="base.400">
{t('controlLayers.referenceImage')}
</Text>
<Spacer />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
icon={<PiXBold />}
tooltip={t('controlLayers.deleteReferenceImage')}
aria-label={t('controlLayers.deleteReferenceImage')}
onClick={onDeleteIPAdapter}
colorScheme="error"
/>
</Flex>
<Flex alignItems="center" gap={2} p={4}>
<Text textAlign="center" color="base.300">
<Trans
i18nKey="controlLayers.referenceImageEmptyState"
components={{
UploadButton: (
<Button
isDisabled={isBusy}
size="sm"
variant="link"
color="base.300"
{...uploadApi.getUploadButtonProps()}
/>
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}}
/>
</Text>
</Flex>
<input {...uploadApi.getUploadInputProps()} />
<DndDropTarget
dndTarget={setRegionalGuidanceReferenceImageDndTarget}
dndTargetData={dndTargetData}
label={t('controlLayers.useImage')}
isDisabled={isBusy}
/>
</Flex>
);
});
RegionalGuidanceIPAdapterSettingsEmptyState.displayName = 'RegionalGuidanceIPAdapterSettingsEmptyState';

View File

@@ -5,6 +5,7 @@ import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/
import { StagingAreaToolbarImageCountButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarImageCountButton';
import { StagingAreaToolbarNextButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarNextButton';
import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarPrevButton';
import { StagingAreaToolbarSaveAsMenu } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveAsMenu';
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
import { memo } from 'react';
@@ -21,6 +22,7 @@ export const StagingAreaToolbar = memo(() => {
<StagingAreaToolbarAcceptButton />
<StagingAreaToolbarToggleShowResultsButton />
<StagingAreaToolbarSaveSelectedToGalleryButton />
<StagingAreaToolbarSaveAsMenu />
<StagingAreaToolbarDiscardSelectedButton />
<StagingAreaToolbarDiscardAllButton />
</ButtonGroup>

View File

@@ -0,0 +1,136 @@
import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { selectSelectedImage } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { createNewCanvasEntityFromImage } from 'features/imageActions/actions';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiDotsThreeBold } from 'react-icons/pi';
import { imageDTOToFile, uploadImage } from 'services/api/endpoints/images';
const uploadImageArg = { image_category: 'general', is_intermediate: true, silent: true } as const;
export const StagingAreaToolbarSaveAsMenu = memo(() => {
const { t } = useTranslation();
const selectedImage = useAppSelector(selectSelectedImage);
const store = useAppStore();
const onClickNewRasterLayerFromImage = useCallback(async () => {
if (!selectedImage) {
return;
}
const { dispatch, getState } = store;
const file = await imageDTOToFile(selectedImage.imageDTO);
const imageDTO = await uploadImage({ file, ...uploadImageArg });
createNewCanvasEntityFromImage({
imageDTO,
type: 'raster_layer',
dispatch,
getState,
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [selectedImage, store, t]);
const onClickNewControlLayerFromImage = useCallback(async () => {
if (!selectedImage) {
return;
}
const { dispatch, getState } = store;
const file = await imageDTOToFile(selectedImage.imageDTO);
const imageDTO = await uploadImage({ file, ...uploadImageArg });
createNewCanvasEntityFromImage({
imageDTO,
type: 'control_layer',
dispatch,
getState,
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [selectedImage, store, t]);
const onClickNewInpaintMaskFromImage = useCallback(async () => {
if (!selectedImage) {
return;
}
const { dispatch, getState } = store;
const file = await imageDTOToFile(selectedImage.imageDTO);
const imageDTO = await uploadImage({ file, ...uploadImageArg });
createNewCanvasEntityFromImage({
imageDTO,
type: 'inpaint_mask',
dispatch,
getState,
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [selectedImage, store, t]);
const onClickNewRegionalGuidanceFromImage = useCallback(async () => {
if (!selectedImage) {
return;
}
const { dispatch, getState } = store;
const file = await imageDTOToFile(selectedImage.imageDTO);
const imageDTO = await uploadImage({ file, ...uploadImageArg });
createNewCanvasEntityFromImage({
imageDTO,
type: 'regional_guidance',
dispatch,
getState,
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [selectedImage, store, t]);
return (
<Menu>
<MenuButton
as={IconButton}
aria-label={t('controlLayers.newLayerFromImage')}
tooltip={t('controlLayers.newLayerFromImage')}
icon={<PiDotsThreeBold />}
colorScheme="invokeBlue"
isDisabled={!selectedImage}
/>
<MenuList>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewInpaintMaskFromImage} isDisabled={!selectedImage}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewRegionalGuidanceFromImage}
isDisabled={!selectedImage}
>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewControlLayerFromImage} isDisabled={!selectedImage}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRasterLayerFromImage} isDisabled={!selectedImage}>
{t('controlLayers.rasterLayer')}
</MenuItem>
</MenuList>
</Menu>
);
});
StagingAreaToolbarSaveAsMenu.displayName = 'StagingAreaToolbarSaveAsMenu';

View File

@@ -7,7 +7,7 @@ import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold } from 'react-icons/pi';
import { uploadImage } from 'services/api/endpoints/images';
import { imageDTOToFile, uploadImage } from 'services/api/endpoints/images';
const TOAST_ID = 'SAVE_STAGING_AREA_IMAGE_TO_GALLERY';
@@ -25,11 +25,8 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
// To save the image to gallery, we will download it and re-upload it. This allows the user to delete the image
// the gallery without borking the canvas, which may need this image to exist.
const result = await withResultAsync(async () => {
// Download the image
const res = await fetch(selectedImage.imageDTO.image_url);
const blob = await res.blob();
// Create a new file with the same name, which we will upload
const file = new File([blob], `copy_of_${selectedImage.imageDTO.image_name}`, { type: 'image/png' });
const file = await imageDTOToFile(selectedImage.imageDTO);
await uploadImage({
file,

View File

@@ -1,6 +1,7 @@
import { Flex } from '@invoke-ai/ui-library';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityHeaderWarnings } from 'features/controlLayers/components/common/CanvasEntityHeaderWarnings';
import { CanvasEntityIsBookmarkedForQuickSwitchToggle } from 'features/controlLayers/components/common/CanvasEntityIsBookmarkedForQuickSwitchToggle';
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -11,6 +12,7 @@ export const CanvasEntityHeaderCommonActions = memo(() => {
return (
<Flex alignSelf="stretch">
<CanvasEntityHeaderWarnings />
<CanvasEntityIsBookmarkedForQuickSwitchToggle />
{entityIdentifier.type !== 'reference_image' && <CanvasEntityIsLockedToggle />}
<CanvasEntityEnabledToggle />

View File

@@ -0,0 +1,101 @@
import { Flex, IconButton, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import {
getControlLayerWarnings,
getGlobalReferenceImageWarnings,
getInpaintMaskWarnings,
getRasterLayerWarnings,
getRegionalGuidanceWarnings,
} from 'features/controlLayers/store/validators';
import type { TFunction } from 'i18next';
import { upperFirst } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiWarningBold } from 'react-icons/pi';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => {
return createSelector(selectCanvasSlice, selectModel, (canvas, model) => {
// This component is used within a <CanvasEntityStateGate /> so we can safely assume that the entity exists.
// Should never throw.
const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings');
let warnings: string[] = [];
const entityType = entity.type;
if (entityType === 'control_layer') {
warnings = getControlLayerWarnings(entity, model);
} else if (entityType === 'regional_guidance') {
warnings = getRegionalGuidanceWarnings(entity, model);
} else if (entityType === 'inpaint_mask') {
warnings = getInpaintMaskWarnings(entity, model);
} else if (entityType === 'raster_layer') {
warnings = getRasterLayerWarnings(entity, model);
} else if (entityType === 'reference_image') {
warnings = getGlobalReferenceImageWarnings(entity, model);
} else {
assert<Equals<typeof entityType, never>>(false, 'Unexpected entity type');
}
// Return a stable reference if there are no warnings
if (warnings.length === 0) {
return EMPTY_ARRAY;
}
return warnings.map((w) => t(w)).map(upperFirst);
});
};
export const CanvasEntityHeaderWarnings = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const { t } = useTranslation();
const isEnabled = useEntityIsEnabled(entityIdentifier);
const selectWarnings = useMemo(() => buildSelectWarnings(entityIdentifier, t), [entityIdentifier, t]);
const warnings = useAppSelector(selectWarnings);
if (warnings.length === 0) {
return null;
}
return (
// Using IconButton here bc it matches the styling of the actual buttons in the header without any fanagling, but
// it's not a button
<IconButton
as="span"
size="sm"
variant="link"
alignSelf="stretch"
aria-label="warnings"
tooltip={<TooltipContent warnings={warnings} />}
icon={<PiWarningBold />}
colorScheme="warning"
isDisabled={!isEnabled}
/>
);
});
CanvasEntityHeaderWarnings.displayName = 'CanvasEntityHeaderWarnings';
const TooltipContent = memo((props: { warnings: string[] }) => {
const { t } = useTranslation();
return (
<Flex flexDir="column">
<Text>{t('controlLayers.warnings.problemsFound')}:</Text>
<UnorderedList>
{props.warnings.map((warning, index) => (
<ListItem key={index}>{warning}</ListItem>
))}
</UnorderedList>
</Flex>
);
});
TooltipContent.displayName = 'TooltipContent';

View File

@@ -29,7 +29,13 @@ import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
/** @knipignore */
/**
* Selects the default control adapter configuration based on the model configurations and the base.
*
* Be sure to clone the output of this selector before modifying it!
*
* @knipignore
*/
export const selectDefaultControlAdapter = createSelector(
selectModelConfigsQuery,
selectBase,
@@ -52,6 +58,11 @@ export const selectDefaultControlAdapter = createSelector(
}
);
/**
* Selects the default IP adapter configuration based on the model configurations and the base.
*
* Be sure to clone the output of this selector before modifying it!
*/
export const selectDefaultIPAdapter = createSelector(
selectModelConfigsQuery,
selectBase,
@@ -117,7 +128,9 @@ export const useAddRegionalReferenceImage = () => {
const func = useCallback(() => {
const overrides: Partial<CanvasRegionalGuidanceState> = {
referenceImages: [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter: defaultIPAdapter }],
referenceImages: [
{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter: deepClone(defaultIPAdapter) },
],
};
dispatch(rgAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);
@@ -129,7 +142,7 @@ export const useAddGlobalReferenceImage = () => {
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const func = useCallback(() => {
const overrides = { ipAdapter: defaultIPAdapter };
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
dispatch(referenceImageAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);
@@ -140,7 +153,7 @@ export const useAddRegionalGuidanceIPAdapter = (entityIdentifier: CanvasEntityId
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const func = useCallback(() => {
dispatch(rgIPAdapterAdded({ entityIdentifier, overrides: { ipAdapter: defaultIPAdapter } }));
dispatch(rgIPAdapterAdded({ entityIdentifier, overrides: { ipAdapter: deepClone(defaultIPAdapter) } }));
}, [defaultIPAdapter, dispatch, entityIdentifier]);
return func;

View File

@@ -9,6 +9,7 @@ import {
getEmptyRect,
getKonvaNodeDebugAttrs,
getPrefixedId,
offsetCoord,
} from 'features/controlLayers/konva/util';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, RectWithRotation } from 'features/controlLayers/store/types';
@@ -558,6 +559,25 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.manager.stateApi.setEntityPosition({ entityIdentifier: this.parent.entityIdentifier, position });
};
nudgeBy = (offset: Coordinate) => {
// We can immediately move both the proxy rect and layer objects so we don't have to wait for a redux round-trip,
// which can take up to 2ms in my testing. This is optional, but can make the interaction feel more responsive,
// especially on lower-end devices.
// Get the relative position of the layer's objects, according to konva
const position = this.konva.proxyRect.position();
// Offset the position by the nudge amount
const newPosition = offsetCoord(position, offset);
// Set the new position of the proxy rect - this doesn't move the layer objects - only the outline rect
this.konva.proxyRect.setAttrs(newPosition);
// Sync the layer objects with the proxy rect - moves them to the new position
this.syncObjectGroupWithProxyRect();
// Push to redux. The state change will do a round-trip, and eventually make it back to the canvas classes, at
// which point the layer will be moved to the new position.
this.manager.stateApi.moveEntityBy({ entityIdentifier: this.parent.entityIdentifier, offset });
this.log.trace({ offset }, 'Nudged');
};
syncObjectGroupWithProxyRect = () => {
this.parent.renderer.konva.objectGroup.setAttrs({
x: this.konva.proxyRect.x(),

View File

@@ -20,7 +20,8 @@ import {
controlLayerAdded,
entityBrushLineAdded,
entityEraserLineAdded,
entityMoved,
entityMovedBy,
entityMovedTo,
entityRasterized,
entityRectAdded,
entityReset,
@@ -40,7 +41,8 @@ import type {
EntityBrushLineAddedPayload,
EntityEraserLineAddedPayload,
EntityIdentifierPayload,
EntityMovedPayload,
EntityMovedByPayload,
EntityMovedToPayload,
EntityRasterizedPayload,
EntityRectAddedPayload,
Rect,
@@ -139,8 +141,15 @@ export class CanvasStateApiModule extends CanvasModuleBase {
/**
* Updates an entity's position, pushing state to redux.
*/
setEntityPosition = (arg: EntityMovedPayload) => {
this.store.dispatch(entityMoved(arg));
setEntityPosition = (arg: EntityMovedToPayload) => {
this.store.dispatch(entityMovedTo(arg));
};
/**
* Moves an entity by the give offset, pushing state to redux.
*/
moveEntityBy = (arg: EntityMovedByPayload) => {
this.store.dispatch(entityMovedBy(arg));
};
/**

View File

@@ -1,9 +1,24 @@
import { $focusedRegion } from 'common/hooks/focus';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Coordinate } from 'features/controlLayers/store/types';
import type { Logger } from 'roarr';
type CanvasMoveToolModuleConfig = {
/**
* The number of pixels to nudge the entity by when moving with the arrow keys.
*/
NUDGE_PX: number;
};
const DEFAULT_CONFIG: CanvasMoveToolModuleConfig = {
NUDGE_PX: 1,
};
type NudgeKey = 'ArrowLeft' | 'ArrowRight' | 'ArrowUp' | 'ArrowDown';
export class CanvasMoveToolModule extends CanvasModuleBase {
readonly type = 'move_tool';
readonly id: string;
@@ -12,6 +27,9 @@ export class CanvasMoveToolModule extends CanvasModuleBase {
readonly manager: CanvasManager;
readonly log: Logger;
config: CanvasMoveToolModuleConfig = DEFAULT_CONFIG;
nudgeOffsets: Record<NudgeKey, Coordinate>;
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
@@ -19,8 +37,18 @@ export class CanvasMoveToolModule extends CanvasModuleBase {
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
this.nudgeOffsets = {
ArrowLeft: { x: -this.config.NUDGE_PX, y: 0 },
ArrowRight: { x: this.config.NUDGE_PX, y: 0 },
ArrowUp: { x: 0, y: -this.config.NUDGE_PX },
ArrowDown: { x: 0, y: this.config.NUDGE_PX },
};
}
isNudgeKey(key: string): key is NudgeKey {
return this.nudgeOffsets[key as NudgeKey] !== undefined;
}
syncCursorStyle = () => {
@@ -32,4 +60,45 @@ export class CanvasMoveToolModule extends CanvasModuleBase {
selectedEntity.transformer.syncCursorStyle();
}
};
nudge = (nudgeKey: NudgeKey) => {
if ($focusedRegion.get() !== 'canvas') {
return;
}
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (!selectedEntity) {
return;
}
if (
selectedEntity.$isDisabled.get() ||
selectedEntity.$isEmpty.get() ||
selectedEntity.$isLocked.get() ||
selectedEntity.$isEntityTypeHidden.get()
) {
return;
}
const isBusy = this.manager.$isBusy.get();
const isMoveToolSelected = this.parent.$tool.get() === 'move';
const isThisEntityTransforming = this.manager.stateApi.$transformingAdapter.get() === selectedEntity;
if (isBusy) {
// When the canvas is busy, we shouldn't allow nudging - except when the canvas is busy transforming the selected
// entity. Nudging is allowed during transformation, regardless of the selected tool.
if (!isThisEntityTransforming) {
return;
}
} else {
// Otherwise, the canvas is not busy, and we should only allow nudging when the move tool is selected.
if (!isMoveToolSelected) {
return;
}
}
const offset = this.nudgeOffsets[nudgeKey];
selectedEntity.transformer.nudgeBy(offset);
};
}

View File

@@ -528,11 +528,16 @@ export class CanvasToolModule extends CanvasModuleBase {
};
onKeyDown = (e: KeyboardEvent) => {
if (e.repeat) {
if (e.target instanceof HTMLInputElement || e.target instanceof HTMLTextAreaElement) {
return;
}
if (e.target instanceof HTMLInputElement || e.target instanceof HTMLTextAreaElement) {
// Handle nudging - must be before repeat, as we may want to catch repeating keys
if (this.tools.move.isNudgeKey(e.key)) {
this.tools.move.nudge(e.key);
}
if (e.repeat) {
return;
}

View File

@@ -18,6 +18,7 @@ import type {
CanvasEntityType,
CanvasInpaintMaskState,
CanvasMetadata,
EntityMovedByPayload,
FillStyle,
RegionalGuidanceReferenceImageState,
RgbColor,
@@ -51,7 +52,7 @@ import type {
EntityBrushLineAddedPayload,
EntityEraserLineAddedPayload,
EntityIdentifierPayload,
EntityMovedPayload,
EntityMovedToPayload,
EntityRasterizedPayload,
EntityRectAddedPayload,
IPMethodV2,
@@ -1201,7 +1202,7 @@ export const canvasSlice = createSlice({
}
entity.fill.style = style;
},
entityMoved: (state, action: PayloadAction<EntityMovedPayload>) => {
entityMovedTo: (state, action: PayloadAction<EntityMovedToPayload>) => {
const { entityIdentifier, position } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
@@ -1212,6 +1213,20 @@ export const canvasSlice = createSlice({
entity.position = position;
}
},
entityMovedBy: (state, action: PayloadAction<EntityMovedByPayload>) => {
const { entityIdentifier, offset } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
if (!isRenderableEntity(entity)) {
return;
}
entity.position.x += offset.x;
entity.position.y += offset.y;
},
entityRasterized: (state, action: PayloadAction<EntityRasterizedPayload>) => {
const { entityIdentifier, imageObject, position, replaceObjects, isSelected } = action.payload;
const entity = selectEntity(state, entityIdentifier);
@@ -1505,7 +1520,8 @@ export const {
entityIsLockedToggled,
entityFillColorChanged,
entityFillStyleChanged,
entityMoved,
entityMovedTo,
entityMovedBy,
entityDuplicated,
entityRasterized,
entityBrushLineAdded,

View File

@@ -83,7 +83,7 @@ const initialState: ParamsState = {
canvasCoherenceMode: 'Gaussian Blur',
canvasCoherenceMinDenoise: 0,
canvasCoherenceEdgeSize: 16,
infillMethod: 'patchmatch',
infillMethod: 'lama',
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },

View File

@@ -439,7 +439,8 @@ export type EntityIdentifierPayload<
entityIdentifier: CanvasEntityIdentifier<U>;
} & T;
export type EntityMovedPayload = EntityIdentifierPayload<{ position: Coordinate }>;
export type EntityMovedToPayload = EntityIdentifierPayload<{ position: Coordinate }>;
export type EntityMovedByPayload = EntityIdentifierPayload<{ offset: Coordinate }>;
export type EntityBrushLineAddedPayload = EntityIdentifierPayload<{
brushLine: CanvasBrushLineState | CanvasBrushLineWithPressureState;
}>;

View File

@@ -0,0 +1,153 @@
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
const WARNINGS = {
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
RG_NO_PROMPTS_OR_IP_ADAPTERS: 'controlLayers.warnings.rgNoPromptsOrIPAdapters',
RG_NEGATIVE_PROMPT_NOT_SUPPORTED: 'controlLayers.warnings.rgNegativePromptNotSupported',
RG_REFERENCE_IMAGES_NOT_SUPPORTED: 'controlLayers.warnings.rgReferenceImagesNotSupported',
RG_AUTO_NEGATIVE_NOT_SUPPORTED: 'controlLayers.warnings.rgAutoNegativeNotSupported',
RG_NO_REGION: 'controlLayers.warnings.rgNoRegion',
IP_ADAPTER_NO_MODEL_SELECTED: 'controlLayers.warnings.ipAdapterNoModelSelected',
IP_ADAPTER_INCOMPATIBLE_BASE_MODEL: 'controlLayers.warnings.ipAdapterIncompatibleBaseModel',
IP_ADAPTER_NO_IMAGE_SELECTED: 'controlLayers.warnings.ipAdapterNoImageSelected',
CONTROL_ADAPTER_NO_MODEL_SELECTED: 'controlLayers.warnings.controlAdapterNoModelSelected',
CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL: 'controlLayers.warnings.controlAdapterIncompatibleBaseModel',
CONTROL_ADAPTER_NO_CONTROL: 'controlLayers.warnings.controlAdapterNoControl',
} as const;
type WarningTKey = (typeof WARNINGS)[keyof typeof WARNINGS];
export const getRegionalGuidanceWarnings = (
entity: CanvasRegionalGuidanceState,
model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
if (entity.objects.length === 0) {
// Layer is in empty state
warnings.push(WARNINGS.RG_NO_REGION);
}
if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) {
// Must have at least 1 prompt or IP Adapter
warnings.push(WARNINGS.RG_NO_PROMPTS_OR_IP_ADAPTERS);
}
if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (model.base === 'flux') {
// Some features are not supported for flux models
if (entity.negativePrompt !== null) {
warnings.push(WARNINGS.RG_NEGATIVE_PROMPT_NOT_SUPPORTED);
}
if (entity.referenceImages.length > 0) {
warnings.push(WARNINGS.RG_REFERENCE_IMAGES_NOT_SUPPORTED);
}
if (entity.autoNegative) {
warnings.push(WARNINGS.RG_AUTO_NEGATIVE_NOT_SUPPORTED);
}
} else {
entity.referenceImages.forEach(({ ipAdapter }) => {
if (!ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
if (!ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
});
}
}
return warnings;
};
export const getGlobalReferenceImageWarnings = (
entity: CanvasReferenceImageState,
model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
if (!entity.ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (entity.ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
}
if (!entity.ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
return warnings;
};
export const getControlLayerWarnings = (
entity: CanvasControlLayerState,
model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
if (entity.objects.length === 0) {
// Layer is in empty state
warnings.push(WARNINGS.CONTROL_ADAPTER_NO_CONTROL);
}
if (!entity.controlAdapter.model) {
// No model selected
warnings.push(WARNINGS.CONTROL_ADAPTER_NO_MODEL_SELECTED);
} else if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (entity.controlAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
}
return warnings;
};
export const getRasterLayerWarnings = (
_entity: CanvasRasterLayerState,
_model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
// There are no warnings at the moment for raster layers.
return warnings;
};
export const getInpaintMaskWarnings = (
_entity: CanvasInpaintMaskState,
_model: ParameterModel | null
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
// There are no warnings at the moment for inpaint masks.
return warnings;
};

View File

@@ -77,6 +77,32 @@ export const ImageMenuItemNewLayerFromImageSubMenu = memo(() => {
});
}, [imageDTO, imageViewer, store, t]);
const onClickNewRegionalReferenceImageFromImage = useCallback(() => {
const { dispatch, getState } = store;
createNewCanvasEntityFromImage({ imageDTO, type: 'reference_image', dispatch, getState });
dispatch(sentImageToCanvas());
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [imageDTO, imageViewer, store, t]);
const onClickNewGlobalReferenceImageFromImage = useCallback(() => {
const { dispatch, getState } = store;
createNewCanvasEntityFromImage({ imageDTO, type: 'regional_guidance_with_reference_image', dispatch, getState });
dispatch(sentImageToCanvas());
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [imageDTO, imageViewer, store, t]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiPlusBold />}>
<Menu {...subMenu.menuProps}>
@@ -104,6 +130,20 @@ export const ImageMenuItemNewLayerFromImageSubMenu = memo(() => {
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRasterLayerFromImage} isDisabled={isBusy}>
{t('controlLayers.rasterLayer')}
</MenuItem>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewRegionalReferenceImageFromImage}
isDisabled={isBusy}
>
{t('controlLayers.referenceImageRegional')}
</MenuItem>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewGlobalReferenceImageFromImage}
isDisabled={isBusy}
>
{t('controlLayers.referenceImageGlobal')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>

View File

@@ -20,6 +20,7 @@ import { selectBboxModelBase, selectBboxRect } from 'features/controlLayers/stor
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
CanvasEntityState,
CanvasEntityType,
CanvasInpaintMaskState,
CanvasRasterLayerState,
@@ -134,14 +135,16 @@ export const createNewCanvasEntityFromImage = (arg: {
type: CanvasEntityType | 'regional_guidance_with_reference_image';
dispatch: AppDispatch;
getState: () => RootState;
overrides?: Partial<Pick<CanvasEntityState, 'isEnabled' | 'isLocked' | 'name'>>;
}) => {
const { type, imageDTO, dispatch, getState } = arg;
const { type, imageDTO, dispatch, getState, overrides: _overrides } = arg;
const state = getState();
const imageObject = imageDTOToImageObject(imageDTO);
const { x, y } = selectBboxRect(state);
const overrides = {
objects: [imageObject],
position: { x, y },
..._overrides,
};
switch (type) {
case 'raster_layer': {
@@ -166,13 +169,13 @@ export const createNewCanvasEntityFromImage = (arg: {
break;
}
case 'reference_image': {
const ipAdapter = selectDefaultIPAdapter(getState());
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
break;
}
case 'regional_guidance_with_reference_image': {
const ipAdapter = selectDefaultIPAdapter(getState());
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
const referenceImages = [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }];
dispatch(rgAdded({ overrides: { referenceImages }, isSelected: true }));
@@ -288,14 +291,14 @@ export const newCanvasFromImage = (arg: {
break;
}
case 'reference_image': {
const ipAdapter = selectDefaultIPAdapter(getState());
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
dispatch(canvasReset());
dispatch(referenceImageAdded({ overrides: { ipAdapter }, isSelected: true }));
break;
}
case 'regional_guidance_with_reference_image': {
const ipAdapter = selectDefaultIPAdapter(getState());
const ipAdapter = deepClone(selectDefaultIPAdapter(getState()));
ipAdapter.image = imageDTOToImageWithDims(imageDTO);
const referenceImages = [{ id: getPrefixedId('regional_guidance_reference_image'), ipAdapter }];
dispatch(canvasReset());

View File

@@ -1,35 +1,41 @@
import { logger } from 'app/logging/logger';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type {
CanvasControlLayerState,
ControlNetConfig,
Rect,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types';
import { getControlLayerWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import type { ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
type AddControlNetsArg = {
manager: CanvasManager;
entities: CanvasControlLayerState[];
g: Graph;
rect: Rect;
collector: Invocation<'collect'>;
model: ParameterModel;
};
type AddControlNetsResult = {
addedControlNets: number;
};
export const addControlNets = async (
manager: CanvasManager,
layers: CanvasControlLayerState[],
g: Graph,
rect: Rect,
collector: Invocation<'collect'>,
base: BaseModelType
): Promise<AddControlNetsResult> => {
const validControlLayers = layers
.filter((layer) => layer.isEnabled)
.filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
.filter((layer) => layer.controlAdapter.type === 'controlnet');
export const addControlNets = async ({
manager,
entities,
g,
rect,
collector,
model,
}: AddControlNetsArg): Promise<AddControlNetsResult> => {
const validControlLayers = entities
.filter((entity) => entity.isEnabled)
.filter((entity) => entity.controlAdapter.type === 'controlnet')
.filter((entity) => getControlLayerWarnings(entity, model).length === 0);
const result: AddControlNetsResult = {
addedControlNets: 0,
@@ -54,22 +60,31 @@ export const addControlNets = async (
return result;
};
type AddT2IAdaptersArg = {
manager: CanvasManager;
entities: CanvasControlLayerState[];
g: Graph;
rect: Rect;
collector: Invocation<'collect'>;
model: ParameterModel;
};
type AddT2IAdaptersResult = {
addedT2IAdapters: number;
};
export const addT2IAdapters = async (
manager: CanvasManager,
layers: CanvasControlLayerState[],
g: Graph,
rect: Rect,
collector: Invocation<'collect'>,
base: BaseModelType
): Promise<AddT2IAdaptersResult> => {
const validControlLayers = layers
.filter((layer) => layer.isEnabled)
.filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
.filter((layer) => layer.controlAdapter.type === 't2i_adapter');
export const addT2IAdapters = async ({
manager,
entities,
g,
rect,
collector,
model,
}: AddT2IAdaptersArg): Promise<AddT2IAdaptersResult> => {
const validControlLayers = entities
.filter((entity) => entity.isEnabled)
.filter((entity) => entity.controlAdapter.type === 't2i_adapter')
.filter((entity) => getControlLayerWarnings(entity, model).length === 0);
const result: AddT2IAdaptersResult = {
addedT2IAdapters: 0,
@@ -145,11 +160,3 @@ const addT2IAdapterToGraph = (
g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item');
};
const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
// Must be have a model
const hasModel = Boolean(controlAdapter.model);
// Model must match the current base model
const modelMatchesBase = controlAdapter.model?.base === base;
return hasModel && modelMatchesBase;
};

View File

@@ -1,19 +1,23 @@
import type { CanvasReferenceImageState } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
type AddIPAdaptersResult = {
addedIPAdapters: number;
};
export const addIPAdapters = (
ipAdapters: CanvasReferenceImageState[],
g: Graph,
collector: Invocation<'collect'>,
base: BaseModelType
): AddIPAdaptersResult => {
const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity, base));
type AddIPAdaptersArg = {
entities: CanvasReferenceImageState[];
g: Graph;
collector: Invocation<'collect'>;
model: ParameterModel;
};
export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {
const validIPAdapters = entities.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
const result: AddIPAdaptersResult = {
addedIPAdapters: 0,
@@ -76,11 +80,3 @@ const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: In
g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item');
};
const isValidIPAdapter = ({ isEnabled, ipAdapter }: CanvasReferenceImageState, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ipAdapter.model);
const modelMatchesBase = ipAdapter.model?.base === base;
const hasImage = Boolean(ipAdapter.image);
return isEnabled && hasModel && modelMatchesBase && hasImage;
};

View File

@@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
CanvasRegionalGuidanceState,
IPAdapterConfig,
Rect,
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { BaseModelType, Invocation } from 'services/api/types';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
@@ -23,19 +20,26 @@ type AddedRegionResult = {
addedIPAdapters: number;
};
const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => {
const isEnabled = rg.isEnabled;
const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0;
return isEnabled && (hasTextPrompt || hasIPAdapter);
type AddRegionsArg = {
manager: CanvasManager;
regions: CanvasRegionalGuidanceState[];
g: Graph;
bbox: Rect;
model: ParameterModel;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
posCondCollect: Invocation<'collect'>;
negCondCollect: Invocation<'collect'> | null;
ipAdapterCollect: Invocation<'collect'>;
};
/**
* Adds regional guidance to the graph
* @param manager The canvas manager
* @param regions Array of regions to add
* @param g The graph to add the layers to
* @param base The base model type
* @param denoise The main denoise node
* @param bbox The bounding box
* @param model The main model
* @param posCond The positive conditioning node
* @param negCond The negative conditioning node
* @param posCondCollect The positive conditioning collector
@@ -44,22 +48,28 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) =>
* @returns A promise that resolves to the regions that were successfully added to the graph
*/
export const addRegions = async (
manager: CanvasManager,
regions: CanvasRegionalGuidanceState[],
g: Graph,
bbox: Rect,
base: BaseModelType,
denoise: Invocation<'denoise_latents'>,
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
posCondCollect: Invocation<'collect'>,
negCondCollect: Invocation<'collect'>,
ipAdapterCollect: Invocation<'collect'>
): Promise<AddedRegionResult[]> => {
const isSDXL = base === 'sdxl';
export const addRegions = async ({
manager,
regions,
g,
bbox,
model,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollect,
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
const isSDXL = model.base === 'sdxl';
const isFLUX = model.base === 'flux';
const validRegions = regions.filter((rg) => {
if (!rg.isEnabled) {
return false;
}
return getRegionalGuidanceWarnings(rg, model).length === 0;
});
const validRegions = regions.filter((rg) => isValidRegion(rg, base));
const results: AddedRegionResult[] = [];
for (const region of validRegions) {
@@ -94,20 +104,27 @@ export const addRegions = async (
if (region.positivePrompt) {
// The main positive conditioning node
result.addedPositivePrompt = true;
const regionalPosCond = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
}
: {
type: 'compel',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
}
);
let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
if (isSDXL) {
regionalPosCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
});
} else if (isFLUX) {
regionalPosCond = g.addNode({
type: 'flux_text_encoder',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
});
} else {
regionalPosCond = g.addNode({
type: 'compel',
id: getPrefixedId('prompt_region_positive_cond'),
prompt: region.positivePrompt,
});
}
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask');
// Connect the conditioning to the collector
@@ -115,38 +132,55 @@ export const addRegions = async (
// Copy the connections to the "global" positive conditioning node to the regional cond
if (posCond.type === 'compel') {
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
// Clone the edge, but change the destination node to the regional conditioning node
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
} else if (posCond.type === 'sdxl_compel_prompt') {
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
} else if (posCond.type === 'flux_text_encoder') {
for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
} else {
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
// Clone the edge, but change the destination node to the regional conditioning node
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
assert(false, 'Unsupported positive conditioning node type.');
}
}
if (region.negativePrompt) {
result.addedNegativePrompt = true;
assert(negCond, 'Negative conditioning node is required if there is a negative prompt');
assert(negCondCollect, 'Negative conditioning collector is required if there is a negative prompt');
// The main negative conditioning node
const regionalNegCond = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
style: region.negativePrompt,
}
: {
type: 'compel',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
}
);
result.addedNegativePrompt = true;
let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
if (isSDXL) {
regionalNegCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
style: region.negativePrompt,
});
} else if (isFLUX) {
regionalNegCond = g.addNode({
type: 'flux_text_encoder',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
});
} else {
regionalNegCond = g.addNode({
type: 'compel',
id: getPrefixedId('prompt_region_negative_cond'),
prompt: region.negativePrompt,
});
}
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask');
// Connect the conditioning to the collector
@@ -158,17 +192,27 @@ export const addRegions = async (
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
} else {
} else if (negCond.type === 'sdxl_compel_prompt') {
for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
} else if (negCond.type === 'flux_text_encoder') {
for (const edge of g.getEdgesTo(negCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
} else {
assert(false, 'Unsupported negative conditioning node type.');
}
}
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (region.autoNegative && region.positivePrompt) {
assert(negCondCollect, 'Negative conditioning collector is required if there is an auto-negative setting');
result.addedAutoNegativePositivePrompt = true;
// We re-use the mask image, but invert it when converting to tensor
const invertTensorMask = g.addNode({
@@ -178,20 +222,27 @@ export const addRegions = async (
// Connect the OG mask image to the inverted mask-to-tensor node
g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
const regionalPosCondInverted = g.addNode(
isSDXL
? {
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
style: region.positivePrompt,
}
: {
type: 'compel',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
}
);
let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
if (isSDXL) {
regionalPosCondInverted = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
style: region.positivePrompt,
});
} else if (isFLUX) {
regionalPosCondInverted = g.addNode({
type: 'flux_text_encoder',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
});
} else {
regionalPosCondInverted = g.addNode({
type: 'compel',
id: getPrefixedId('prompt_region_positive_cond_inverted'),
prompt: region.positivePrompt,
});
}
// Connect the inverted mask to the conditioning
g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask');
// Connect the conditioning to the negative collector
@@ -203,20 +254,26 @@ export const addRegions = async (
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
} else {
} else if (posCond.type === 'sdxl_compel_prompt') {
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
} else if (posCond.type === 'flux_text_encoder') {
for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
} else {
assert(false, 'Unsupported positive conditioning node type.');
}
}
const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) =>
isValidIPAdapter(ipAdapter, base)
);
for (const { id, ipAdapter } of region.referenceImages) {
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
for (const { id, ipAdapter } of validRGIPAdapters) {
result.addedIPAdapters++;
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
@@ -248,11 +305,3 @@ export const addRegions = async (
return results;
};
const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ipAdapter.model);
const modelMatchesBase = ipAdapter.model?.base === base;
const hasImage = Boolean(ipAdapter.image);
return hasModel && modelMatchesBase && hasImage;
};

View File

@@ -11,6 +11,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -79,7 +80,10 @@ export const buildFLUXGraph = async (
id: getPrefixedId('flux_text_encoder'),
prompt: positivePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
id: getPrefixedId('pos_cond_collect'),
});
const denoise = g.addNode({
type: 'flux_denoise',
id: getPrefixedId('flux_denoise'),
@@ -104,13 +108,12 @@ export const buildFLUXGraph = async (
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning');
g.addEdge(denoise, 'latents', l2i, 'latents');
addFLUXLoRAs(state, g, denoise, modelLoader, posCond);
g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'flux');
@@ -196,31 +199,50 @@ export const buildFLUXGraph = async (
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
const controlNetResult = await addControlNets({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: controlNetCollector,
model: modelConfig,
});
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
g.deleteNode(controlNetCollector.id);
}
const ipAdapterCollector = g.addNode({
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
const ipAdapterResult = addIPAdapters({
entities: canvas.referenceImages.entities,
g,
collector: ipAdapterCollect,
model: modelConfig,
});
const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
const regionsResult = await addRegions({
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
model: modelConfig,
posCond,
negCond: null,
posCondCollect,
negCondCollect: null,
ipAdapterCollect,
});
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {

View File

@@ -227,14 +227,14 @@ export const buildSD1Graph = async (
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
const controlNetResult = await addControlNets({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: controlNetCollector,
model: modelConfig,
});
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
@@ -245,46 +245,50 @@ export const buildSD1Graph = async (
type: 'collect',
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters(
const t2iAdapterResult = await addT2IAdapters({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
t2iAdapterCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: t2iAdapterCollector,
model: modelConfig,
});
if (t2iAdapterResult.addedT2IAdapters > 0) {
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
} else {
g.deleteNode(t2iAdapterCollector.id);
}
const ipAdapterCollector = g.addNode({
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
const regionsResult = await addRegions(
manager,
canvas.regionalGuidance.entities,
const ipAdapterResult = addIPAdapters({
entities: canvas.referenceImages.entities,
g,
canvas.bbox.rect,
modelConfig.base,
denoise,
collector: ipAdapterCollect,
model: modelConfig,
});
const regionsResult = await addRegions({
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
model: modelConfig,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollector
);
ipAdapterCollect,
});
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {

View File

@@ -232,14 +232,14 @@ export const buildSDXLGraph = async (
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
const controlNetResult = await addControlNets({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: controlNetCollector,
model: modelConfig,
});
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
@@ -250,46 +250,50 @@ export const buildSDXLGraph = async (
type: 'collect',
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters(
const t2iAdapterResult = await addT2IAdapters({
manager,
canvas.controlLayers.entities,
entities: canvas.controlLayers.entities,
g,
canvas.bbox.rect,
t2iAdapterCollector,
modelConfig.base
);
rect: canvas.bbox.rect,
collector: t2iAdapterCollector,
model: modelConfig,
});
if (t2iAdapterResult.addedT2IAdapters > 0) {
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
} else {
g.deleteNode(t2iAdapterCollector.id);
}
const ipAdapterCollector = g.addNode({
const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
const regionsResult = await addRegions(
manager,
canvas.regionalGuidance.entities,
const ipAdapterResult = addIPAdapters({
entities: canvas.referenceImages.entities,
g,
canvas.bbox.rect,
modelConfig.base,
denoise,
collector: ipAdapterCollect,
model: modelConfig,
});
const regionsResult = await addRegions({
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
model: modelConfig,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollector
);
ipAdapterCollect,
});
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
g.deleteNode(ipAdapterCollector.id);
g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {

View File

@@ -58,6 +58,9 @@ const isAllowedOutputField = (nodeType: string, fieldName: string) => {
if (RESERVED_OUTPUT_FIELD_NAMES.includes(fieldName)) {
return false;
}
if (nodeType === 'image_batch' && fieldName !== 'image') {
return false;
}
return true;
};

View File

@@ -4,13 +4,13 @@ import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';
import { useClearQueue } from './ClearQueueConfirmationAlertDialog';
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';
type Props = ButtonProps;
const ClearQueueButton = (props: Props) => {
const { t } = useTranslation();
const clearQueue = useClearQueue();
const clearQueue = useClearQueueDialog();
return (
<>

View File

@@ -1,51 +1,15 @@
import { ConfirmationAlertDialog, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';
const [useClearQueueConfirmationAlertDialog] = buildUseBoolean(false);
export const useClearQueue = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
export const useClearQueueDialog = () => {
const dialog = useClearQueueConfirmationAlertDialog();
const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useClearQueueMutation({
fixedCacheKey: 'clearQueue',
});
const clearQueue = useCallback(async () => {
if (!queueStatus?.queue.total) {
return;
}
try {
await trigger().unwrap();
toast({
id: 'QUEUE_CLEAR_SUCCEEDED',
title: t('queue.clearSucceeded'),
status: 'success',
});
dispatch(listCursorChanged(undefined));
dispatch(listPriorityChanged(undefined));
} catch {
toast({
id: 'QUEUE_CLEAR_FAILED',
title: t('queue.clearFailed'),
status: 'error',
});
}
}, [queueStatus?.queue.total, trigger, dispatch, t]);
const isDisabled = useMemo(() => !isConnected || !queueStatus?.queue.total, [isConnected, queueStatus?.queue.total]);
const { clearQueue, isLoading, isDisabled, queueStatus } = useClearQueue();
return {
clearQueue,
@@ -61,7 +25,7 @@ export const useClearQueue = () => {
export const ClearQueueConfirmationsAlertDialog = memo(() => {
useAssertSingleton('ClearQueueConfirmationsAlertDialog');
const { t } = useTranslation();
const clearQueue = useClearQueue();
const clearQueue = useClearQueueDialog();
return (
<ConfirmationAlertDialog

View File

@@ -4,11 +4,11 @@ import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold, PiXBold } from 'react-icons/pi';
import { useClearQueue } from './ClearQueueConfirmationAlertDialog';
import { useClearQueueDialog } from './ClearQueueConfirmationAlertDialog';
export const ClearQueueIconButton = memo((_) => {
const { t } = useTranslation();
const clearQueue = useClearQueue();
const clearQueue = useClearQueueDialog();
const cancelCurrentQueueItem = useCancelCurrentQueueItem();
// Show the single item clear button when shift is pressed

View File

@@ -147,8 +147,6 @@ const UpscaleTabTooltipContent = memo(({ prepend = false }: { prepend?: boolean
<ReasonsList reasons={reasons} />
</>
)}
<StyledDivider />
<AddingToText />
</Flex>
);
});
@@ -180,8 +178,6 @@ const WorkflowsTabTooltipContent = memo(({ prepend = false }: { prepend?: boolea
<ReasonsList reasons={reasons} />
</>
)}
<StyledDivider />
<AddingToText />
</Flex>
);
});

View File

@@ -1,8 +1,9 @@
import { IconButton, Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SessionMenuItems } from 'common/components/SessionMenuItems';
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { useClearQueueDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { QueueCountBadge } from 'features/queue/components/QueueCountBadge';
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
import { usePauseProcessor } from 'features/queue/hooks/usePauseProcessor';
import { useResumeProcessor } from 'features/queue/hooks/useResumeProcessor';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@@ -17,7 +18,8 @@ export const QueueActionsMenuButton = memo(() => {
const { t } = useTranslation();
const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue');
const clearQueue = useClearQueue();
const cancelCurrent = useCancelCurrentQueueItem();
const clearQueue = useClearQueueDialog();
const {
resumeProcessor,
isLoading: isLoadingResumeProcessor,
@@ -44,9 +46,9 @@ export const QueueActionsMenuButton = memo(() => {
<MenuItem
isDestructive
icon={<PiXBold />}
onClick={clearQueue.openDialog}
isLoading={clearQueue.isLoading}
isDisabled={clearQueue.isDisabled}
onClick={cancelCurrent.cancelQueueItem}
isLoading={cancelCurrent.isLoading}
isDisabled={cancelCurrent.isDisabled}
>
{t('queue.cancelTooltip')}
</MenuItem>

View File

@@ -0,0 +1,50 @@
import { useStore } from '@nanostores/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { $isConnected } from 'services/events/stores';
export const useClearQueue = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data: queueStatus } = useGetQueueStatusQuery();
const isConnected = useStore($isConnected);
const [trigger, { isLoading }] = useClearQueueMutation({
fixedCacheKey: 'clearQueue',
});
const clearQueue = useCallback(async () => {
if (!queueStatus?.queue.total) {
return;
}
try {
await trigger().unwrap();
toast({
id: 'QUEUE_CLEAR_SUCCEEDED',
title: t('queue.clearSucceeded'),
status: 'success',
});
dispatch(listCursorChanged(undefined));
dispatch(listPriorityChanged(undefined));
} catch {
toast({
id: 'QUEUE_CLEAR_FAILED',
title: t('queue.clearFailed'),
status: 'error',
});
}
}, [queueStatus?.queue.total, trigger, dispatch, t]);
const isDisabled = useMemo(() => !isConnected || !queueStatus?.queue.total, [isConnected, queueStatus?.queue.total]);
return {
clearQueue,
isLoading,
queueStatus,
isDisabled,
};
};

Some files were not shown because too many files have changed in this diff Show More