Compare commits


160 Commits

Author SHA1 Message Date
psychedelicious
83e33a4810 chore: bump version to v5.6.0 2025-01-21 17:58:47 +11:00
psychedelicious
e635028477 chore(ui): update whats new copy 2025-01-21 17:58:47 +11:00
psychedelicious
b7b8f8a9e5 fix(nodes): remove WithMetadata from non-image-outputting node 2025-01-21 17:58:47 +11:00
psychedelicious
e926d2f24b fix(nodes): add beta classification to new inpainting support nodes 2025-01-21 17:58:47 +11:00
psychedelicious
ad8885c456 chore(ui): typegen 2025-01-21 17:45:32 +11:00
psychedelicious
cf4c79fe2e feat(nodes): add PasteImageIntoBoundingBoxInvocation 2025-01-21 17:45:32 +11:00
psychedelicious
e0edfe6c40 feat(nodes): add CropImageToBoundingBoxInvocation 2025-01-21 17:45:32 +11:00
psychedelicious
8a0a37191a feat(nodes): add GetMaskBoundingBoxInvocation 2025-01-21 17:45:32 +11:00
psychedelicious
7dbd5f150a feat(nodes): add BoundingBoxField.tuple() to get bbox as PIL tuple 2025-01-21 17:45:32 +11:00
psychedelicious
1ad65ffd53 feat(nodes): re-title "Mask from ID" -> "Mask from Segmented Image" 2025-01-21 17:45:32 +11:00
psychedelicious
14b5c871dc feat(nodes): simplify MaskFromIDInvocation 2025-01-21 17:45:32 +11:00
psychedelicious
8d2b4e2bf5 feat(nodes): support FLUX, SD3 in ideal_size 2025-01-21 17:45:32 +11:00
psychedelicious
aba70eacab fix(ui): field handle positioning for non-batch fields
Accidentally overwrote some reactflow styles, which caused field handles to be positioned differently for non-batch fields. Just a minor visual issue.
2025-01-21 11:49:49 +11:00
Riccardo Giovanetti
4b67175b1b translationBot(ui): update translation (Italian)
Currently translated at 99.1% (1690 of 1704 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-01-21 09:12:45 +11:00
Hosted Weblate
e3423d1ba8 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-01-21 09:12:45 +11:00
Linos
87fb00ff5d translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1697 of 1697 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 99.2% (1684 of 1697 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 99.7% (1676 of 1681 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 99.3% (1670 of 1681 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 99.5% (1658 of 1666 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1652 of 1652 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-01-21 09:12:45 +11:00
Riccardo Giovanetti
d99a9ffb72 translationBot(ui): update translation (Italian)
Currently translated at 99.3% (1642 of 1652 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-01-21 09:12:45 +11:00
Hosted Weblate
7964f438dc translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-01-21 09:12:45 +11:00
Linos
b130a3a9ee translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1652 of 1652 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-01-21 09:12:45 +11:00
Riccardo Giovanetti
a6b32160b2 translationBot(ui): update translation (Italian)
Currently translated at 99.3% (1642 of 1652 strings)

translationBot(ui): update translation (Italian)

Currently translated at 99.3% (1641 of 1652 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-01-21 09:12:45 +11:00
psychedelicious
7d110cc9d3 fix(ui): disable dynamic prompts generators pending resolution of infinite recursion issue
Dynamic prompts string generators can cause an infinite feedback loop when added to the linear view.

The root cause is how these generators handle "resolving" their collections. They hit the dynamic prompts HTTP API within the view component to get the prompts, then set the batch node's internal state with those values.

When the same generator is rendered in both the node editor view and linear view and the timing is just right, that state update causes an infinite feedback loop between the two components as they respond to the state updates from the other component.

The other generators never store the generated values in the batch node's internal state. The values are "resolved" just-in-time as they are needed.

To fix this, the batch value "resolver" utilities could be made async and hit the API. But there's a problem - the resolver utilities are used within the "are we ready to invoke? are there any problems with the current settings?" redux selectors, which are strictly synchronous. To fix that, we can refactor that "are we ready to invoke?" logic to not use redux selectors, so the whole thing could be async.

It's not a big change but I'm not going to spend time on it at the moment.

So, until I address this, the dynamic prompts generators are disabled.
2025-01-21 09:00:40 +11:00
psychedelicious
82122645e8 refactor(ui): organize special handling for batch field types 2025-01-21 07:17:29 +11:00
psychedelicious
f5c5b73383 fix(ui): string batch nodes' inputs get batch type 2025-01-21 07:17:29 +11:00
psychedelicious
2b2ec67cd6 fix(nodes): allow connection input on string batch nodes 2025-01-21 07:17:29 +11:00
Ryan Dick
66bc225bd3 Add troubleshooting instructions for the Windows page file issue to the Low-VRAM docs. 2025-01-20 08:58:41 +11:00
psychedelicious
7535d2e188 feat(ui): use translation for load from file buttons 2025-01-20 08:57:42 +11:00
psychedelicious
3dff87aeee feat(ui): better layout for generator load from file buttons 2025-01-20 08:57:42 +11:00
psychedelicious
b14bf1e0f4 chore(ui): lint 2025-01-20 08:57:42 +11:00
psychedelicious
4fdc6eec9d feat(ui): support loading from file for string input generators 2025-01-20 08:57:42 +11:00
psychedelicious
180a67d11b feat(ui): small fontsize on generator textareas 2025-01-20 08:57:42 +11:00
psychedelicious
ec816d3c04 feat(ui): improved dynamicprompts generator
- Split into two (random and combinatorial) - lots of fiddly logic to do both in one generator.
- Update to support seeds for random.
2025-01-20 08:57:42 +11:00
psychedelicious
7dcc2dafbc chore(ui): typegen 2025-01-20 08:57:42 +11:00
psychedelicious
81da5210f0 feat(api): add seed field to dynamicprompts 2025-01-20 08:57:42 +11:00
psychedelicious
eb976a2ab0 feat(ui): add dynamic prompts string generator (WIP) 2025-01-20 08:57:42 +11:00
psychedelicious
724028d974 feat(ui): port improved string parsing logic from string generator to float & int 2025-01-20 08:57:42 +11:00
psychedelicious
43c98fd99e feat(ui): add string generator 2025-01-20 08:57:42 +11:00
psychedelicious
526d64a5e2 feat(nodes): add string generator 2025-01-20 08:57:42 +11:00
psychedelicious
58c6c6db53 feat(ui): make string collection component same as number collection
Same UI & better perf thanks to a different structure.
2025-01-20 08:57:42 +11:00
psychedelicious
8a41e09de3 feat(ui): seeded random generators
- Add a JS Mersenne Twister implementation as a dependency to use as the seeded PRNG. This is not a cryptographically secure algorithm.
- Add nullish seed field to float and integer random generators.
- Add UI to control the seed.
- When the seed is not set, behaviour is unchanged - the values are randomized when you Invoke. When the seed is set, the generated values are deterministic for that seed, so we can display them to the user.
2025-01-18 08:45:56 +11:00
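The seeded behaviour described in this commit is easy to picture. A minimal Python sketch of the idea (the UI itself uses the JS Mersenne Twister dependency mentioned above; the function here is purely illustrative):

```python
import random

def random_floats(count: int, lo: float, hi: float, seed: int | None = None) -> list[float]:
    # With a seed, the sequence is deterministic and a preview can be shown;
    # with seed=None, fresh values are drawn each time (i.e. on Invoke).
    rng = random.Random(seed)
    return [rng.uniform(lo, hi) for _ in range(count)]

assert random_floats(3, 0.0, 1.0, seed=42) == random_floats(3, 0.0, 1.0, seed=42)
```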
psychedelicious
c24eae1968 chore: bump version to v5.6.0rc4 2025-01-17 16:29:20 +11:00
psychedelicious
a6b207a0d9 fix(ui): string field textarea accidentally readonly 2025-01-17 16:17:13 +11:00
psychedelicious
eea5ecdd69 Update invokeai_version.py 2025-01-17 13:15:20 +11:00
psychedelicious
50de54dcfd chore(ui): lint 2025-01-17 12:48:58 +11:00
psychedelicious
04b893f982 chore(ui): typegen 2025-01-17 12:48:58 +11:00
psychedelicious
4c655eeb48 chore(ui): lint 2025-01-17 12:48:58 +11:00
psychedelicious
298abab883 feat(ui): improved generator text area styling 2025-01-17 12:48:58 +11:00
psychedelicious
bd477ded2e feat(ui): better preview for generators 2025-01-17 12:48:58 +11:00
psychedelicious
0b64d21980 tidy(ui): remove extraneous reset button on generators 2025-01-17 12:48:58 +11:00
psychedelicious
91d5f8537d feat(ui): add integer & float parse string generators 2025-01-17 12:48:58 +11:00
psychedelicious
e498e1f07c feat(ui): reworked float/int generators (arithmetic sequence, linear dist, uniform rand dist) 2025-01-17 12:48:58 +11:00
psychedelicious
73a3f195dc fix(ui): remove nonfunctional button 2025-01-17 12:48:58 +11:00
psychedelicious
8cc790a030 fix(ui): batch size calculations 2025-01-17 12:48:58 +11:00
psychedelicious
57265c8869 feat(ui): rip out generator modal functionality 2025-01-17 12:48:58 +11:00
psychedelicious
66d08eaa1c fix(ui): translation for generators 2025-01-17 12:48:58 +11:00
psychedelicious
d69e90ca5e feat(ui): support integer generators 2025-01-17 12:48:58 +11:00
psychedelicious
f345fde512 fix(ui): use utils to get default float generator values 2025-01-17 12:48:58 +11:00
psychedelicious
508c702289 feat(nodes): remove default values for generator; let UI handle it 2025-01-17 12:48:58 +11:00
psychedelicious
8fbd2f9a97 feat(nodes): add integer generator nodes 2025-01-17 12:48:58 +11:00
psychedelicious
bfb26af36a chore(ui): lint 2025-01-17 12:48:58 +11:00
psychedelicious
4400bc69f2 feat(ui): don't show generator preview for random generators 2025-01-17 12:48:58 +11:00
psychedelicious
10f2c0dc9a feat(ui): support generator nodes (wip)
- Add `batch` property to field type object to differentiate between executable nodes and batch/generator nodes.
- Support for float generators
2025-01-17 12:48:58 +11:00
psychedelicious
5b0326fc49 chore(ui): typegen 2025-01-17 12:48:58 +11:00
psychedelicious
2f9a0a250d feat(nodes): generators as nodes 2025-01-17 12:48:58 +11:00
psychedelicious
5d03328dc6 tidy(nodes): code dedupe for batch node init errors 2025-01-17 12:48:58 +11:00
psychedelicious
1fb32aec28 tidy(nodes): move batch nodes to own file 2025-01-17 12:48:58 +11:00
psychedelicious
2bbcd42036 chore(ui): knip 2025-01-17 12:34:54 +11:00
psychedelicious
2f40f7bafd tweak(ui): error verbiage for collection size mismatch 2025-01-17 12:34:54 +11:00
psychedelicious
65dd01bf3a fix(ui): invoke tooltip for invalid/empty batches 2025-01-17 12:34:54 +11:00
psychedelicious
81fc525f8a chore(ui): lint 2025-01-17 12:34:54 +11:00
psychedelicious
d2dd5ee408 fix(ui): unclosed JSX tag 2025-01-17 12:34:54 +11:00
psychedelicious
b4b1daeb26 feat(ui): validate all batch nodes have connection 2025-01-17 12:34:54 +11:00
psychedelicious
90c4c10e14 feat(ui): show batch group in node title 2025-01-17 12:34:54 +11:00
psychedelicious
30e33d30d5 fix(ui): handle batch group ids of "None" correctly 2025-01-17 12:34:54 +11:00
psychedelicious
3df3be6c34 tweak(ui): enum field selects have size="sm" 2025-01-17 12:34:54 +11:00
psychedelicious
4e917bf2b2 chore(ui): typegen 2025-01-17 12:34:54 +11:00
psychedelicious
26e6e28a13 feat(nodes): add title for batch_group_id field 2025-01-17 12:34:54 +11:00
psychedelicious
f9cee42a06 tweak(ui): node editor layout padding 2025-01-17 12:34:54 +11:00
psychedelicious
1b8da023b8 chore(ui): typegen 2025-01-17 12:34:54 +11:00
psychedelicious
05f1026812 feat(nodes): batch_group_id is a literal of options 2025-01-17 12:34:54 +11:00
psychedelicious
ca1bd254ea feat(ui): rename "link_id" -> "batch_group_id" 2025-01-17 12:34:54 +11:00
psychedelicious
29645326b9 chore(ui): typegen 2025-01-17 12:34:54 +11:00
psychedelicious
c23a2abc82 feat(nodes): rename "link_id" -> "batch_group_id" 2025-01-17 12:34:54 +11:00
psychedelicious
803ec8e904 feat(ui): add zipped batch collection size validation 2025-01-17 12:34:54 +11:00
psychedelicious
0abc0be931 fix(ui): allow batch nodes without link id (i.e. product batch nodes) to have mismatched collection sizes 2025-01-17 12:34:54 +11:00
psychedelicious
edff16124f feat(ui): support zipped batch nodes 2025-01-17 12:34:54 +11:00
psychedelicious
2e4110a29a chore(ui): typegen 2025-01-17 12:34:54 +11:00
psychedelicious
7ee51f3e14 feat(nodes): add link_id field to batch nodes
This is used to link batch nodes into zipped batch data collections.
2025-01-17 12:34:54 +11:00
psychedelicious
8ae75dbc35 chore(ui): typegen 2025-01-17 12:34:54 +11:00
psychedelicious
9265716b07 chore(ui): lint 2025-01-17 12:19:04 +11:00
psychedelicious
27b9c07711 chore(ui): typegen 2025-01-17 12:19:04 +11:00
psychedelicious
9dcbe3cc8f tweak(ui): number collection styling 2025-01-17 12:19:04 +11:00
psychedelicious
30165f66c3 feat(ui): string collection batch items are input not textarea 2025-01-17 12:19:04 +11:00
psychedelicious
deb70edc75 fix(ui): translation key 2025-01-17 12:19:04 +11:00
psychedelicious
d82d990b23 feat(ui): add number range generators 2025-01-17 12:19:04 +11:00
psychedelicious
2c64b60d32 Revert "feat(ui): rough out number generators for number collection fields"
This reverts commit 41cc6f1f96bca2a51727f21bd727ca48eab669bc.
2025-01-17 12:19:04 +11:00
psychedelicious
4e8c6d931d Revert "feat(ui): number collection generator supports floats"
This reverts commit 9da3339b513de9575ffbf6ce880b3097217b199d.
2025-01-17 12:19:04 +11:00
psychedelicious
9049e6e0f3 Revert "feat(ui): more batch generator stuff"
This reverts commit 111a29c7b4fc6b5062a0a37ce704a6508ff58dd8.
2025-01-17 12:19:04 +11:00
psychedelicious
3cb5f8536b feat(ui): more batch generator stuff 2025-01-17 12:19:04 +11:00
psychedelicious
38e50cc7aa tidy(ui): abstract out batch detection logic 2025-01-17 12:19:04 +11:00
psychedelicious
5bff6123b9 feat(nodes): add default value for batch nodes 2025-01-17 12:19:04 +11:00
psychedelicious
d63ff560d6 feat(ui): number collection generator supports floats 2025-01-17 12:19:04 +11:00
psychedelicious
acceac8304 fix(ui): do not set number collection field to undefined when removing last item 2025-01-17 12:19:04 +11:00
psychedelicious
96671d12bd fix(ui): filter out batch nodes when checking readiness on workflows tab 2025-01-17 12:19:04 +11:00
psychedelicious
584601d03f perf(ui): memoize selector in workflows 2025-01-17 12:19:04 +11:00
psychedelicious
b1c4ec0888 feat(ui): rough out number generators for number collection fields 2025-01-17 12:19:04 +11:00
psychedelicious
db5f016826 fix(nodes): allow batch datum items to mix ints and floats
Unfortunately we cannot do strict floats or ints.

The batch data models don't specify the value types; they instead rely on pydantic parsing. JSON doesn't differentiate between floats and ints, so a float `1.0` gets parsed as `1` in Python.

As a result, we _must_ accept mixed floats and ints for BatchDatum.items.

Tests and validation updated to handle this.

Maybe we should update the BatchDatum model to have a `type` field? Then we could parse as float or int, depending on the inputs...
2025-01-17 12:19:04 +11:00
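A minimal sketch of the parsing behaviour this commit describes (`Datum` is a hypothetical stand-in for `BatchDatum`): a JS client serializes the float `1.0` as `1`, and Python parses it back as an int, so the items field must tolerate both types.

```python
import json
from pydantic import BaseModel

class Datum(BaseModel):  # hypothetical stand-in for BatchDatum
    items: list[int | float]

# JSON has a single number type: the float 1.0 round-trips as "1",
# which Python parses back as the int 1.
payload = json.loads('{"items": [1, 2.5, 3]}')
print(Datum(**payload).items)  # [1, 2.5, 3] -> int, float, int
```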
psychedelicious
c1fd28472d fix(ui): float batch data creation 2025-01-17 12:19:04 +11:00
psychedelicious
0c5958675a chore(ui): lint 2025-01-17 12:19:04 +11:00
psychedelicious
912e07f2c8 tidy(ui): use zod typeguard builder util for fields 2025-01-17 12:19:04 +11:00
psychedelicious
f853b24868 chore(ui): typegen 2025-01-17 12:19:04 +11:00
psychedelicious
4f900b22dc feat(ui): validate number item multipleOf 2025-01-17 12:19:04 +11:00
psychedelicious
5823532941 feat(ui): validate string item lengths 2025-01-17 12:19:04 +11:00
psychedelicious
bfe6d98cba feat(ui): support float batches 2025-01-17 12:19:04 +11:00
psychedelicious
c26b3cd54f refactor(ui): abstract out helper to add batch data 2025-01-17 12:19:04 +11:00
psychedelicious
c012d832d2 fix(ui): typo 2025-01-17 12:19:04 +11:00
psychedelicious
9d11d2aabd refactor(ui): abstract out field validators 2025-01-17 12:19:04 +11:00
psychedelicious
a5f1587ce7 feat(ui): add template validation for integer collection items 2025-01-17 12:19:04 +11:00
psychedelicious
0b26bb1ca3 feat(ui): add template validation for string collection items 2025-01-17 12:19:04 +11:00
psychedelicious
0f1e632117 feat(nodes): add float batch node 2025-01-17 12:19:04 +11:00
psychedelicious
b212332b3e feat(ui): support integer batches 2025-01-17 12:19:04 +11:00
psychedelicious
90a91ff438 feat(nodes): add integer batch node 2025-01-17 12:19:04 +11:00
psychedelicious
b52b271dc4 feat(ui): support string batches 2025-01-17 12:19:04 +11:00
psychedelicious
e077fe8046 refactor(ui): streamline image field collection input logic, support multiple images w/ same name in collection 2025-01-17 12:19:04 +11:00
psychedelicious
368957b208 tweak(ui): image field collection input component styling 2025-01-17 12:19:04 +11:00
psychedelicious
27277e1fd6 docs(ui): improved comments for image batch node special handling 2025-01-17 12:19:04 +11:00
psychedelicious
236c0d89e7 feat(nodes): add string batch node 2025-01-17 12:19:04 +11:00
psychedelicious
b807170701 fix(ui): typo in error message for image collection fields 2025-01-17 12:19:04 +11:00
Ryan Dick
c5d2de3169 Revise the default logic for the model cache RAM limit (#7566)
## Summary

This PR revises the logic for calculating the model cache RAM limit. See
the code for thorough documentation of the change.

The updated logic is more conservative in the amount of RAM that it will
use. This will likely be a better default for more users. Of course,
users can still choose to set a more aggressive limit by overriding the
logic with `max_cache_ram_gb`.

## Related Issues / Discussions

- Should help with https://github.com/invoke-ai/InvokeAI/issues/7563

## QA Instructions

Exercise all heuristics:
- [x] Heuristic 1
- [x] Heuristic 2
- [x] Heuristic 3
- [x] Heuristic 4

## Merge Plan

- [x] Merge https://github.com/invoke-ai/InvokeAI/pull/7565 first and
update the target branch

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-01-16 19:59:14 -05:00
Ryan Dick
f7511bfd94 Add keep_ram_copy_of_weights config option (#7565)
## Summary

This PR adds a `keep_ram_copy_of_weights` config option. The default (and
legacy) behavior is `true`. The tradeoffs for this setting are as
follows:
- `keep_ram_copy_of_weights: true`: Faster model switching and LoRA
patching.
- `keep_ram_copy_of_weights: false`: Lower average RAM load (may not
help significantly with peak RAM).

## Related Issues / Discussions

- Helps with https://github.com/invoke-ai/InvokeAI/issues/7563
- The Low-VRAM docs are updated to include this feature in
https://github.com/invoke-ai/InvokeAI/pull/7566

## QA Instructions

- Test with `enable_partial_load: false` and `keep_ram_copy_of_weights:
false`.
  - [x] RAM usage when model is loaded is reduced.
  - [x] Model loading / unloading works as expected.
  - [x] LoRA patching still works.
- Test with `enable_partial_load: false` and `keep_ram_copy_of_weights:
true`.
  - [x] Behavior should be unchanged.
- Test with `enable_partial_load: true` and `keep_ram_copy_of_weights:
false`.
  - [x] RAM usage when model is loaded is reduced.
  - [x] Model loading / unloading works as expected.
  - [x] LoRA patching still works.
- Test with `enable_partial_load: true` and `keep_ram_copy_of_weights:
true`.
  - [x] Behavior should be unchanged.

- [x] Smoke test CPU-only and MPS with default configs.

## Merge Plan

- [x] Merge https://github.com/invoke-ai/InvokeAI/pull/7564 first and
change target branch.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-01-16 19:57:02 -05:00
Ryan Dick
0abb5ea114 Reduce peak memory during FLUX model load (#7564)
## Summary

Prior to this change, there were several cases where we initialized the
weights of a FLUX model before loading its state dict (and, to make
things worse, in some cases the weights were in float32). This PR fixes
a handful of these cases. (I think I found all instances for the FLUX
family of models.)

## Related Issues / Discussions

- Helps with https://github.com/invoke-ai/InvokeAI/issues/7563

## QA Instructions

I tested that model loading still works and that there is no
virtual memory reservation on model initialization for the following
models:
- [x] FLUX VAE
- [x] Full T5 Encoder
- [x] Full FLUX checkpoint
- [x] GGUF FLUX checkpoint

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-01-16 18:47:17 -05:00
Ryan Dick
ce57c4ed2e Update the Low-VRAM docs. 2025-01-16 23:46:07 +00:00
Ryan Dick
0cf51cefe8 Revise the logic for calculating the RAM model cache limit. 2025-01-16 23:46:07 +00:00
Ryan Dick
e5e848d239 Update config docstring. 2025-01-16 22:34:23 +00:00
Ryan Dick
da589b3f1f Memory optimization to load state dicts one module at a time in CachedModelWithPartialLoad when we are not storing a CPU copy of the state dict (i.e. when keep_ram_copy_of_weights=False). 2025-01-16 17:00:33 +00:00
Ryan Dick
36a3869af0 Add keep_ram_copy_of_weights config option. 2025-01-16 15:35:25 +00:00
Ryan Dick
c76d08d1fd Add keep_ram_copy option to CachedModelOnlyFullLoad. 2025-01-16 15:08:23 +00:00
Ryan Dick
04087c38ce Add keep_ram_copy option to CachedModelWithPartialLoad. 2025-01-16 14:51:44 +00:00
Ryan Dick
b2bb359d47 Update the model loading logic for several of the large FLUX-related models to ensure that the model is initialized on the meta device prior to loading the state dict into it. This helps to keep peak memory down. 2025-01-16 02:30:28 +00:00
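The meta-device pattern this commit refers to keeps peak memory down because tensors created on torch's meta device allocate no storage; real memory is only committed when the state dict is loaded. A hedged PyTorch sketch (the module is a stand-in, not InvokeAI's actual loader code):

```python
import torch
import torch.nn as nn

def load_without_double_alloc(state_dict: dict[str, torch.Tensor]) -> nn.Module:
    # Constructing under the meta device allocates no memory for the weights.
    with torch.device("meta"):
        model = nn.Linear(4096, 4096)
    # assign=True (PyTorch >= 2.1) swaps the meta tensors for the loaded ones,
    # so peak memory is ~one copy of the weights instead of two.
    model.load_state_dict(state_dict, assign=True)
    return model

sd = {"weight": torch.randn(4096, 4096), "bias": torch.randn(4096)}
model = load_without_double_alloc(sd)
```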
Mary Hipp
b57aa06d9e take out AbortController logic and simplify dependencies 2025-01-16 09:39:32 +11:00
Mary Hipp
f856246c36 try removing abortcontroller 2025-01-16 09:39:32 +11:00
Mary Hipp
195df2ebe6 remove logic changes, keep logging 2025-01-16 09:39:32 +11:00
Mary Hipp
7b5cef6bd7 lint fix 2025-01-16 09:39:32 +11:00
Mary Hipp
69e7ffaaf5 add logging, remove deps 2025-01-16 09:39:32 +11:00
psychedelicious
993401ad6c fix(ui): hide layer when previewing filter
Previously, when previewing a filter on a layer with some transparency or a filter that changes the alpha, the preview was rendered on top of the layer. The preview blended with the layer, which isn't right.

In this change, the layer is hidden during the preview, and when the filter finishes (having been applied or canceled - the two possible paths), the layer is shown.

Technically, we are hiding and showing the layer's object renderer's konva group, which contains the layer's "real" data.

Another small change was made to prevent a flash of empty layer, by waiting to destroy a previous filter preview image until the new preview image is ready to display.
2025-01-16 09:27:36 +11:00
psychedelicious
8d570dcffc chore(ui): typegen 2025-01-16 09:27:36 +11:00
psychedelicious
3f70e947fd chore: ruff 2025-01-16 09:27:36 +11:00
dunkeroni
157290bef4 add: size option for image noise node and filter 2025-01-16 09:27:36 +11:00
dunkeroni
b7389da89b add: Noise filter on Canvas 2025-01-16 09:27:36 +11:00
dunkeroni
254b89b1f5 add: Blur filter option on canvas 2025-01-16 09:27:36 +11:00
dunkeroni
2b122d7882 add: image noise invocation 2025-01-16 09:27:36 +11:00
dunkeroni
ded9213eb4 trim blur splitting logic 2025-01-16 09:27:36 +11:00
dunkeroni
9d51eb49cd fix: ImageBlurInvocation handles transparency now 2025-01-16 09:27:36 +11:00
dunkeroni
0a6e22bc9e fix: ImagePasteInvocation respects transparency 2025-01-16 09:27:36 +11:00
Ryan Dick
b301785dc8 Normalize the T5 model identifiers so that a FLUX T5 or an SD3 T5 model can be used interchangeably. 2025-01-16 08:33:58 +11:00
psychedelicious
edcdff4f78 fix(ui): round rects when applying transform
Due to limited floating-point precision and konva's `scale` properties, it is possible for the relative rect of an object to have non-integer coordinates and dimensions.

When we go to rasterize and otherwise export images, the HTML canvas API truncates these numbers.

So, we can end up with situations where the relative width and height of a layer are very close to the "real" value, but slightly off.

For example, width and height might be 512px, but the relative rect is calculated to be something like 512.000000003 or 511.9999999997.

In the first case, the truncation results in 512x512 for the dimensions - which is correct. But in the second case, it results in 511x511!

One place where this causes issues is the image action `New Canvas from image -> As Raster Layer (resize)`. For certain input image sizes, this results in an incorrectly resized image. For example, a 1496x1946 input image is resized to 511x511 pixels when the bbox is 512x512.

To fix this, we can round both coords and dimensions of rects when rasterizing.

I've thought through the implications and done some testing. I believe this change will not cause any regressions and only fix edge cases. But, it's possible that something was inadvertently relying on the old behavior.
2025-01-16 01:17:30 +11:00
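The failure mode described above is plain truncation of a near-integer float, e.g.:

```python
import math

w = 511.9999999997    # relative width after konva scale math
print(math.trunc(w))  # 511 -- what canvas-style truncation yields
print(round(w))       # 512 -- the intended dimension
```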
psychedelicious
66e04ea7ab fix(ui): sticky preset image tooltip
There's a bug where preset image tooltips get stuck open in the list.

After much fiddling, debugging, and review of upstream dependencies, I have determined that this is a bug in Chakra-UI v2.

Specifically, it appears to be a race condition related to the Tooltip component's internal use of the `useDisclosure` hook to manage tooltip open state, and the react render cycle.

Unfortunately, Chakra v2 is no longer being updated, and it's a pain in the butt to vendor and fix that component given its dependencies. Not 100% sure I could easily fix it, anyways.

Fortunately, there is a workaround - reduce the tooltip openDelay to 0ms. I prefer the current 500ms delay, but it's better to have too-quick tooltips than too-sticky tooltips...
2025-01-15 09:12:46 -05:00
Ryan Dick
497bc916cc Add unet_config to get_scheduler(...) call in TiledMultiDiffusionDenoiseLatents. 2025-01-15 08:44:08 -05:00
dunkeroni
ebe1873712 fix: only add prediction type if it exists 2025-01-15 08:44:08 -05:00
dunkeroni
59926c320c support v-prediction in denoise_latents.py 2025-01-15 08:44:08 -05:00
Mary Hipp
2d3e2f1907 use window instead of document 2025-01-14 20:01:08 -05:00
99 changed files with 5336 additions and 716 deletions

.nvmrc Normal file

@@ -0,0 +1 @@
v22.12.0


@@ -28,11 +28,12 @@ It is possible to fine-tune the settings for best performance or if you still ge
 ## Details and fine-tuning
-Low-VRAM mode involves 3 features, each of which can be configured or fine-tuned:
-- Partial model loading
-- Dynamic RAM and VRAM cache sizes
-- Working memory
+Low-VRAM mode involves 4 features, each of which can be configured or fine-tuned:
+- Partial model loading (`enable_partial_loading`)
+- Dynamic RAM and VRAM cache sizes (`max_cache_ram_gb`, `max_cache_vram_gb`)
+- Working memory (`device_working_mem_gb`)
+- Keeping a RAM weight copy (`keep_ram_copy_of_weights`)
 Read on to learn about these features and understand how to fine-tune them for your system and use-cases.
@@ -67,12 +68,20 @@ As of v5.6.0, the caches are dynamically sized. The `ram` and `vram` settings ar
 But, if your GPU has enough VRAM to hold models fully, you might get a perf boost by manually setting the cache sizes in `invokeai.yaml`:
 ```yaml
-# Set the RAM cache size to as large as possible, leaving a few GB free for the rest of your system and Invoke.
-# For example, if your system has 32GB RAM, 28GB is a good value.
+# The default max cache RAM size is logged on InvokeAI startup. It is determined based on your system RAM / VRAM.
+# You can override the default value by setting `max_cache_ram_gb`.
+# Increasing `max_cache_ram_gb` will increase the amount of RAM used to cache inactive models, resulting in faster model
+# reloads for the cached models.
+# As an example, if your system has 32GB of RAM and no other heavy processes, setting the `max_cache_ram_gb` to 28GB
+# might be a good value to achieve aggressive model caching.
 max_cache_ram_gb: 28
-# Set the VRAM cache size to be as large as possible while leaving enough room for the working memory of the tasks you will be doing.
-# For example, on a 24GB GPU that will be running unquantized FLUX without any auxiliary models,
-# 18GB is a good value.
+# The default max cache VRAM size is adjusted dynamically based on the amount of available VRAM (taking into
+# consideration the VRAM used by other processes).
+# You can override the default value by setting `max_cache_vram_gb`. Note that this value takes precedence over the
+# `device_working_mem_gb`.
+# It is recommended to set the VRAM cache size to be as large as possible while leaving enough room for the working
+# memory of the tasks you will be doing. For example, on a 24GB GPU that will be running unquantized FLUX without any
+# auxiliary models, 18GB might be a good value.
 max_cache_vram_gb: 18
 ```
@@ -109,6 +118,15 @@ device_working_mem_gb: 4
 Once decoding completes, the model manager "reclaims" the extra VRAM allocated as working memory for future model loading operations.
+### Keeping a RAM weight copy
+Invoke has the option of keeping a RAM copy of all model weights, even when they are loaded onto the GPU. This optimization is _on_ by default, and enables faster model switching and LoRA patching. Disabling this feature will reduce the average RAM load while running Invoke (peak RAM likely won't change), at the cost of slower model switching and LoRA patching. If you have limited RAM, you can disable this optimization:
+```yaml
+# Set to false to reduce the average RAM usage at the cost of slower model switching and LoRA patching.
+keep_ram_copy_of_weights: false
+```
 ### Disabling Nvidia sysmem fallback (Windows only)
 On Windows, Nvidia GPUs are able to use system RAM when their VRAM fills up via **sysmem fallback**. While it sounds like a good idea on the surface, in practice it causes massive slowdowns during generation.
@@ -127,3 +145,19 @@ It is strongly suggested to disable this feature:
 If the sysmem fallback feature sounds familiar, that's because Invoke's partial model loading strategy is conceptually very similar - use VRAM when there's room, else fall back to RAM.
 Unfortunately, the Nvidia implementation is not optimized for applications like Invoke and does more harm than good.
+## Troubleshooting
+### Windows page file
+Invoke has high virtual memory (a.k.a. 'committed memory') requirements. This can cause issues on Windows if the page file size limits are hit. (See this issue for the technical details on why this happens: https://github.com/invoke-ai/InvokeAI/issues/7563).
+If you run out of page file space, InvokeAI may crash. Often, these crashes will happen with one of the following errors:
+- InvokeAI exits with Windows error code `3221225477`
+- InvokeAI crashes without an error, but `eventvwr.msc` reveals an error with code `0xc0000005` (the hex equivalent of `3221225477`)
+If you are running out of page file space, try the following solutions:
+- Make sure that you have sufficient disk space for the page file to grow. Watch your disk usage as Invoke runs. If it climbs near 100% leading up to the crash, then this is very likely the source of the issue. Clear out some disk space to resolve the issue.
+- Make sure that your page file is set to "System managed size" (this is the default) rather than a custom size. Under the "System managed size" policy, the page file will grow dynamically as needed.


@@ -25,6 +25,7 @@ async def parse_dynamicprompts(
     prompt: str = Body(description="The prompt to parse with dynamicprompts"),
     max_prompts: int = Body(ge=1, le=10000, default=1000, description="The max number of prompts to generate"),
     combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
+    seed: int | None = Body(None, description="The seed to use for random generation. Only used if not combinatorial"),
 ) -> DynamicPromptsResponse:
     """Creates a batch process"""
     max_prompts = min(max_prompts, 10000)
@@ -35,7 +36,7 @@ async def parse_dynamicprompts(
             generator = CombinatorialPromptGenerator()
             prompts = generator.generate(prompt, max_prompts=max_prompts)
         else:
-            generator = RandomPromptGenerator()
+            generator = RandomPromptGenerator(seed=seed)
             prompts = generator.generate(prompt, num_images=max_prompts)
     except ParseException as e:
         prompts = [prompt]

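For reference, the upstream dynamicprompts generator wrapped by this endpoint can be exercised directly. A minimal sketch based only on the calls visible in the diff above (the wildcard prompt is illustrative):

```python
from dynamicprompts.generators import RandomPromptGenerator

# The same seed yields the same choices from the {a|b|c} wildcard syntax.
generator = RandomPromptGenerator(seed=42)
prompts = generator.generate("a {red|green|blue} ball", num_images=3)
print(prompts)
```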

@@ -0,0 +1,237 @@
from typing import Literal

from pydantic import BaseModel

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import (
    ImageField,
    Input,
    InputField,
    OutputField,
)
from invokeai.app.invocations.primitives import (
    FloatOutput,
    ImageOutput,
    IntegerOutput,
    StringOutput,
)
from invokeai.app.services.shared.invocation_context import InvocationContext

BATCH_GROUP_IDS = Literal[
    "None",
    "Group 1",
    "Group 2",
    "Group 3",
    "Group 4",
    "Group 5",
]


class NotExecutableNodeError(Exception):
    def __init__(self, message: str = "This class should never be executed or instantiated directly."):
        super().__init__(message)
        pass


class BaseBatchInvocation(BaseInvocation):
    batch_group_id: BATCH_GROUP_IDS = InputField(
        default="None",
        description="The ID of this batch node's group. If provided, all batch nodes with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
        input=Input.Direct,
        title="Batch Group",
    )

    def __init__(self):
        raise NotExecutableNodeError()


@invocation(
    "image_batch",
    title="Image Batch",
    tags=["primitives", "image", "batch", "special"],
    category="primitives",
    version="1.0.0",
    classification=Classification.Special,
)
class ImageBatchInvocation(BaseBatchInvocation):
    """Create a batched generation, where the workflow is executed once for each image in the batch."""

    images: list[ImageField] = InputField(
        default=[], min_length=1, description="The images to batch over", input=Input.Direct
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        raise NotExecutableNodeError()


@invocation(
    "string_batch",
    title="String Batch",
    tags=["primitives", "string", "batch", "special"],
    category="primitives",
    version="1.0.0",
    classification=Classification.Special,
)
class StringBatchInvocation(BaseBatchInvocation):
    """Create a batched generation, where the workflow is executed once for each string in the batch."""

    strings: list[str] = InputField(
        default=[],
        min_length=1,
        description="The strings to batch over",
    )

    def invoke(self, context: InvocationContext) -> StringOutput:
        raise NotExecutableNodeError()


@invocation_output("string_generator_output")
class StringGeneratorOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of strings"""

    strings: list[str] = OutputField(description="The generated strings")


class StringGeneratorField(BaseModel):
    pass


@invocation(
    "string_generator",
    title="String Generator",
    tags=["primitives", "string", "number", "batch", "special"],
    category="primitives",
    version="1.0.0",
    classification=Classification.Special,
)
class StringGenerator(BaseInvocation):
    """Generates a range of strings for use in a batched generation"""

    generator: StringGeneratorField = InputField(
        description="The string generator.",
        input=Input.Direct,
        title="Generator Type",
    )

    def __init__(self):
        raise NotExecutableNodeError()

    def invoke(self, context: InvocationContext) -> StringGeneratorOutput:
        raise NotExecutableNodeError()


@invocation(
    "integer_batch",
    title="Integer Batch",
    tags=["primitives", "integer", "number", "batch", "special"],
    category="primitives",
    version="1.0.0",
    classification=Classification.Special,
)
class IntegerBatchInvocation(BaseBatchInvocation):
    """Create a batched generation, where the workflow is executed once for each integer in the batch."""

    integers: list[int] = InputField(
        default=[],
        min_length=1,
        description="The integers to batch over",
    )

    def invoke(self, context: InvocationContext) -> IntegerOutput:
        raise NotExecutableNodeError()


@invocation_output("integer_generator_output")
class IntegerGeneratorOutput(BaseInvocationOutput):
    integers: list[int] = OutputField(description="The generated integers")


class IntegerGeneratorField(BaseModel):
    pass


@invocation(
    "integer_generator",
    title="Integer Generator",
    tags=["primitives", "int", "number", "batch", "special"],
    category="primitives",
    version="1.0.0",
    classification=Classification.Special,
)
class IntegerGenerator(BaseInvocation):
    """Generates a range of integers for use in a batched generation"""

    generator: IntegerGeneratorField = InputField(
        description="The integer generator.",
        input=Input.Direct,
        title="Generator Type",
    )

    def __init__(self):
        raise NotExecutableNodeError()

    def invoke(self, context: InvocationContext) -> IntegerGeneratorOutput:
        raise NotExecutableNodeError()


@invocation(
    "float_batch",
    title="Float Batch",
    tags=["primitives", "float", "number", "batch", "special"],
    category="primitives",
    version="1.0.0",
    classification=Classification.Special,
)
class FloatBatchInvocation(BaseBatchInvocation):
    """Create a batched generation, where the workflow is executed once for each float in the batch."""

    floats: list[float] = InputField(
        default=[],
        min_length=1,
        description="The floats to batch over",
    )

    def invoke(self, context: InvocationContext) -> FloatOutput:
        raise NotExecutableNodeError()


@invocation_output("float_generator_output")
class FloatGeneratorOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of floats"""

    floats: list[float] = OutputField(description="The generated floats")


class FloatGeneratorField(BaseModel):
    pass


@invocation(
    "float_generator",
    title="Float Generator",
    tags=["primitives", "float", "number", "batch", "special"],
    category="primitives",
    version="1.0.0",
    classification=Classification.Special,
)
class FloatGenerator(BaseInvocation):
    """Generates a range of floats for use in a batched generation"""

    generator: FloatGeneratorField = InputField(
        description="The float generator.",
        input=Input.Direct,
        title="Generator Type",
    )

    def __init__(self):
        raise NotExecutableNodeError()

    def invoke(self, context: InvocationContext) -> FloatGeneratorOutput:
        raise NotExecutableNodeError()


@@ -40,6 +40,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
+from invokeai.backend.model_manager.config import AnyModelConfig
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -85,6 +86,7 @@ def get_scheduler(
     scheduler_info: ModelIdentifierField,
     scheduler_name: str,
     seed: int,
+    unet_config: AnyModelConfig,
 ) -> Scheduler:
     """Load a scheduler and apply some scheduler-specific overrides."""
     # TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
@@ -103,6 +105,9 @@ def get_scheduler(
         "_backup": scheduler_config,
     }
+    if hasattr(unet_config, "prediction_type"):
+        scheduler_config["prediction_type"] = unet_config.prediction_type
     # make dpmpp_sde reproducible (seed can be passed only in initializer)
     if scheduler_class is DPMSolverSDEScheduler:
         scheduler_config["noise_sampler_seed"] = seed
@@ -829,6 +834,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
         _, _, latent_height, latent_width = latents.shape
+        # get the unet's config so that we can pass the base to sd_step_callback()
+        unet_config = context.models.get_config(self.unet.unet.key)
         conditioning_data = self.get_conditioning_data(
             context=context,
             positive_conditioning_field=self.positive_conditioning,
@@ -848,6 +856,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             scheduler_info=self.unet.scheduler,
             scheduler_name=self.scheduler,
             seed=seed,
+            unet_config=unet_config,
         )
         timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
@@ -859,9 +868,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
             denoising_end=self.denoising_end,
         )
-        # get the unet's config so that we can pass the base to sd_step_callback()
-        unet_config = context.models.get_config(self.unet.unet.key)
         ### preview
         def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, unet_config.base)
@@ -1030,6 +1036,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             scheduler_info=self.unet.scheduler,
             scheduler_name=self.scheduler,
             seed=seed,
+            unet_config=unet_config,
         )
         pipeline = self.create_pipeline(unet, scheduler)


@@ -300,6 +300,13 @@ class BoundingBoxField(BaseModel):
            raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
        return self

    def tuple(self) -> Tuple[int, int, int, int]:
        """
        Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.

        This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
        """
        return (self.x_min, self.y_min, self.x_max, self.y_max)


class MetadataField(RootModel[dict[str, Any]]):
    """

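The ordering matters because PIL's `Image.crop()` expects a (left, upper, right, lower) box, so the tuple can be passed straight through. A tiny sketch with stand-in values:

```python
from PIL import Image

bbox = (16, 32, 256, 256)              # stand-in for BoundingBoxField.tuple()
image = Image.new("RGBA", (512, 512))  # stand-in for a real input image
cropped = image.crop(bbox)
print(cropped.size)                    # (240, 224)
```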

@@ -10,6 +10,10 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.t5_model_identifier import (
+    preprocess_t5_encoder_model_identifier,
+    preprocess_t5_tokenizer_model_identifier,
+)
 from invokeai.backend.flux.util import max_seq_lengths
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
@@ -74,8 +78,8 @@ class FluxModelLoaderInvocation(BaseInvocation):
         tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
         clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
-        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
-        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
+        tokenizer2 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
+        t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)


@@ -2,7 +2,7 @@ from contextlib import ExitStack
 from typing import Iterator, Literal, Optional, Tuple
 import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.fields import (
@@ -76,7 +76,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
         context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
     ):
         assert isinstance(t5_text_encoder, T5EncoderModel)
-        assert isinstance(t5_tokenizer, T5Tokenizer)
+        assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
         t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)


@@ -21,7 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput):
     "ideal_size",
     title="Ideal Size",
     tags=["latents", "math", "ideal_size"],
-    version="1.0.3",
+    version="1.0.4",
 )
 class IdealSizeInvocation(BaseInvocation):
     """Calculates the ideal size for generation to avoid duplication"""
@@ -41,11 +41,16 @@ class IdealSizeInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IdealSizeOutput:
         unet_config = context.models.get_config(self.unet.unet.key)
         aspect = self.width / self.height
-        dimension: float = 512
-        if unet_config.base == BaseModelType.StableDiffusion2:
+        if unet_config.base == BaseModelType.StableDiffusion1:
+            dimension = 512
+        elif unet_config.base == BaseModelType.StableDiffusion2:
             dimension = 768
-        elif unet_config.base == BaseModelType.StableDiffusionXL:
+        elif unet_config.base in (BaseModelType.StableDiffusionXL, BaseModelType.Flux, BaseModelType.StableDiffusion3):
             dimension = 1024
+        else:
+            raise ValueError(f"Unsupported model type: {unet_config.base}")
         dimension = dimension * self.multiplier
         min_dimension = math.floor(dimension * 0.5)
         model_area = dimension * dimension  # hardcoded for now since all models are trained on square images


@@ -13,6 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
 )
 from invokeai.app.invocations.constants import IMAGE_MODES
 from invokeai.app.invocations.fields import (
+    BoundingBoxField,
     ColorField,
     FieldDescriptions,
     ImageField,
@@ -23,6 +24,7 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.image_records.image_records_common import ImageCategory
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.misc import SEED_MAX
 from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
 from invokeai.backend.image_util.safety_checker import SafetyChecker
@@ -161,12 +163,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
     crop: bool = InputField(default=False, description="Crop to base image dimensions")
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        base_image = context.images.get_pil(self.base_image.image_name)
-        image = context.images.get_pil(self.image.image_name)
+        base_image = context.images.get_pil(self.base_image.image_name, mode="RGBA")
+        image = context.images.get_pil(self.image.image_name, mode="RGBA")
         mask = None
         if self.mask is not None:
-            mask = context.images.get_pil(self.mask.image_name)
-            mask = ImageOps.invert(mask.convert("L"))
+            mask = context.images.get_pil(self.mask.image_name, mode="L")
+            mask = ImageOps.invert(mask)
         # TODO: probably shouldn't invert mask here... should user be required to do it?
         min_x = min(0, self.x)
@@ -176,7 +178,11 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
         new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
         new_image.paste(base_image, (abs(min_x), abs(min_y)))
-        new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
+        # Create a temporary image to paste the image with transparency
+        temp_image = Image.new("RGBA", new_image.size)
+        temp_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
+        new_image = Image.alpha_composite(new_image, temp_image)
         if self.crop:
             base_w, base_h = base_image.size
@@ -301,14 +307,44 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
     blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.images.get_pil(self.image.image_name)
+        image = context.images.get_pil(self.image.image_name, mode="RGBA")
+        # Split the image into RGBA channels
+        r, g, b, a = image.split()
+        # Premultiply RGB channels by alpha
+        premultiplied_image = ImageChops.multiply(image, a.convert("RGBA"))
+        premultiplied_image.putalpha(a)
+        # Apply the blur
         blur = (
             ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
         )
-        blur_image = image.filter(blur)
+        blurred_image = premultiplied_image.filter(blur)
-        image_dto = context.images.save(image=blur_image)
+        # Split the blurred image into RGBA channels
+        r, g, b, a_orig = blurred_image.split()
+        # Convert to float using NumPy. float 32/64 division are much faster than float 16
+        r = numpy.array(r, dtype=numpy.float32)
+        g = numpy.array(g, dtype=numpy.float32)
+        b = numpy.array(b, dtype=numpy.float32)
+        a = numpy.array(a_orig, dtype=numpy.float32) / 255.0  # Normalize alpha to [0, 1]
+        # Unpremultiply RGB channels by alpha
+        r /= a + 1e-6  # Add a small epsilon to avoid division by zero
+        g /= a + 1e-6
+        b /= a + 1e-6
+        # Convert back to PIL images
+        r = Image.fromarray(numpy.uint8(numpy.clip(r, 0, 255)))
+        g = Image.fromarray(numpy.uint8(numpy.clip(g, 0, 255)))
+        b = Image.fromarray(numpy.uint8(numpy.clip(b, 0, 255)))
+        # Merge back into a single image
+        result_image = Image.merge("RGBA", (r, g, b, a_orig))
+        image_dto = context.images.save(image=result_image)
         return ImageOutput.build(image_dto)
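The premultiply/unpremultiply steps above exist so that the (arbitrary, often black) RGB of fully transparent pixels does not bleed into visible neighbors during the blur. A toy numpy illustration on a single channel:

```python
import numpy as np

# One opaque red pixel next to a fully transparent pixel whose hidden RGB is black.
red = np.array([255.0, 0.0])
alpha = np.array([1.0, 0.0])

naive = red.mean()  # 127.5 -- the hidden black bleeds in, darkening the edge
premultiplied = (red * alpha).mean() / (alpha.mean() + 1e-6)  # ~255 -- red preserved
print(naive, premultiplied)
```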
@@ -962,10 +998,10 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
 @invocation(
     "mask_from_id",
-    title="Mask from ID",
+    title="Mask from Segmented Image",
     tags=["image", "mask", "id"],
     category="image",
-    version="1.0.0",
+    version="1.0.1",
 )
 class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Generate a mask for a particular color in an ID Map"""
@@ -975,40 +1011,24 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
     threshold: int = InputField(default=100, description="Threshold for color detection")
     invert: bool = InputField(default=False, description="Whether or not to invert the mask")
-    def rgba_to_hex(self, rgba_color: tuple[int, int, int, int]):
-        r, g, b, a = rgba_color
-        hex_code = "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, int(a * 255))
-        return hex_code
-
-    def id_to_mask(self, id_mask: Image.Image, color: tuple[int, int, int, int], threshold: int = 100):
-        if id_mask.mode != "RGB":
-            id_mask = id_mask.convert("RGB")
-        # Can directly just use the tuple but I'll leave this rgba_to_hex here
-        # incase anyone prefers using hex codes directly instead of the color picker
-        hex_color_str = self.rgba_to_hex(color)
-        rgb_color = numpy.array([int(hex_color_str[i : i + 2], 16) for i in (1, 3, 5)])
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name, mode="RGBA")
+        np_color = numpy.array(self.color.tuple())
         # Maybe there's a faster way to calculate this distance but I can't think of any right now.
-        color_distance = numpy.linalg.norm(id_mask - rgb_color, axis=-1)
+        color_distance = numpy.linalg.norm(image - np_color, axis=-1)
         # Create a mask based on the threshold and the distance calculated above
-        binary_mask = (color_distance < threshold).astype(numpy.uint8) * 255
+        binary_mask = (color_distance < self.threshold).astype(numpy.uint8) * 255
         # Convert the mask back to PIL
         binary_mask_pil = Image.fromarray(binary_mask)
-        return binary_mask_pil
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.images.get_pil(self.image.image_name)
-        mask = self.id_to_mask(image, self.color.tuple(), self.threshold)
         if self.invert:
-            mask = ImageOps.invert(mask)
+            binary_mask_pil = ImageOps.invert(binary_mask_pil)
-        image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
+        image_dto = context.images.save(image=binary_mask_pil, image_category=ImageCategory.MASK)
         return ImageOutput.build(image_dto)
@@ -1055,3 +1075,123 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
        image_dto = context.images.save(image=generated_image)
        return ImageOutput.build(image_dto)


@invocation(
    "img_noise",
    title="Add Image Noise",
    tags=["image", "noise"],
    category="image",
    version="1.0.1",
)
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Add noise to an image"""

    image: ImageField = InputField(description="The image to add noise to")
    seed: int = InputField(
        default=0,
        ge=0,
        le=SEED_MAX,
        description=FieldDescriptions.seed,
    )
    noise_type: Literal["gaussian", "salt_and_pepper"] = InputField(
        default="gaussian",
        description="The type of noise to add",
    )
    amount: float = InputField(default=0.1, ge=0, le=1, description="The amount of noise to add")
    noise_color: bool = InputField(default=True, description="Whether to add colored noise")
    size: int = InputField(default=1, ge=1, description="The size of the noise points")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name, mode="RGBA")

        # Save out the alpha channel
        alpha = image.getchannel("A")

        # Set the seed for numpy random
        rs = numpy.random.RandomState(numpy.random.MT19937(numpy.random.SeedSequence(self.seed)))

        if self.noise_type == "gaussian":
            if self.noise_color:
                noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size, 3)) * 255
            else:
                noise = rs.normal(0, 1, (image.height // self.size, image.width // self.size)) * 255
                noise = numpy.stack([noise] * 3, axis=-1)
        elif self.noise_type == "salt_and_pepper":
            if self.noise_color:
                noise = rs.choice(
                    [0, 255], (image.height // self.size, image.width // self.size, 3), p=[1 - self.amount, self.amount]
                )
            else:
                noise = rs.choice(
                    [0, 255], (image.height // self.size, image.width // self.size), p=[1 - self.amount, self.amount]
                )
                noise = numpy.stack([noise] * 3, axis=-1)

        noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
            (image.width, image.height), Image.Resampling.NEAREST
        )
        noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")

        # Paste back the alpha channel
        noisy_image.putalpha(alpha)

        image_dto = context.images.save(image=noisy_image)
        return ImageOutput.build(image_dto)


@invocation(
    "crop_image_to_bounding_box",
    title="Crop Image to Bounding Box",
    category="image",
    version="1.0.0",
    tags=["image", "crop"],
    classification=Classification.Beta,
)
class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Crop an image to the given bounding box. If the bounding box is omitted, the image is cropped to the non-transparent pixels."""

    image: ImageField = InputField(description="The image to crop")
    bounding_box: BoundingBoxField | None = InputField(
        default=None, description="The bounding box to crop the image to"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name)

        bounding_box = self.bounding_box.tuple() if self.bounding_box is not None else image.getbbox()
        cropped_image = image.crop(bounding_box)

        image_dto = context.images.save(image=cropped_image)
        return ImageOutput.build(image_dto)


@invocation(
    "paste_image_into_bounding_box",
    title="Paste Image into Bounding Box",
    category="image",
    version="1.0.0",
    tags=["image", "crop"],
    classification=Classification.Beta,
)
class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Paste the source image into the target image at the given bounding box.

    The source image must be the same size as the bounding box, and the bounding box must fit within the target image."""

    source_image: ImageField = InputField(description="The image to paste")
    target_image: ImageField = InputField(description="The image to paste into")
    bounding_box: BoundingBoxField = InputField(description="The bounding box to paste the image into")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        source_image = context.images.get_pil(self.source_image.image_name, mode="RGBA")
        target_image = context.images.get_pil(self.target_image.image_name, mode="RGBA")
        bounding_box = self.bounding_box.tuple()

        target_image.paste(source_image, bounding_box, source_image)

        image_dto = context.images.save(image=target_image)
        return ImageOutput.build(image_dto)

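Together with Get Image Mask Bounding Box and Crop Image to Bounding Box, this node enables a crop, process, and paste-back round trip. A plain-PIL sketch of that flow (all values hypothetical):

```python
from PIL import Image, ImageFilter

target = Image.new("RGBA", (512, 512), (32, 32, 32, 255))  # stand-in canvas
bbox = (128, 128, 384, 384)  # e.g. from a mask's bounding box

# Crop the region, process it, then paste it back at the same coordinates.
region = target.crop(bbox)
processed = region.filter(ImageFilter.GaussianBlur(4))
target.paste(processed, bbox, processed)  # the RGBA image doubles as its own mask
```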

@@ -2,9 +2,22 @@ import numpy as np
 import torch
 from PIL import Image
-from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
-from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    Classification,
+    InvocationContext,
+    invocation,
+)
+from invokeai.app.invocations.fields import (
+    BoundingBoxField,
+    ColorField,
+    ImageField,
+    InputField,
+    TensorField,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.primitives import BoundingBoxOutput, ImageOutput, MaskOutput
 from invokeai.backend.image_util.util import pil_to_np
@@ -201,3 +214,48 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
        image_dto = context.images.save(image=masked_image)
        return ImageOutput.build(image_dto)


WHITE = ColorField(r=255, g=255, b=255, a=255)


@invocation(
    "get_image_mask_bounding_box",
    title="Get Image Mask Bounding Box",
    tags=["mask"],
    category="mask",
    version="1.0.0",
    classification=Classification.Beta,
)
class GetMaskBoundingBoxInvocation(BaseInvocation):
    """Gets the bounding box of the given mask image."""

    mask: ImageField = InputField(description="The mask to crop.")
    margin: int = InputField(default=0, description="Margin to add to the bounding box.")
    mask_color: ColorField = InputField(default=WHITE, description="Color of the mask in the image.")

    def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
        mask = context.images.get_pil(self.mask.image_name, mode="RGBA")
        mask_np = np.array(mask)

        # Convert mask_color to RGBA tuple
        mask_color_rgb = self.mask_color.tuple()

        # Find the bounding box of the mask color
        y, x = np.where(np.all(mask_np == mask_color_rgb, axis=-1))

        if len(x) == 0 or len(y) == 0:
            # No pixels found with the given color
            return BoundingBoxOutput(bounding_box=BoundingBoxField(x_min=0, y_min=0, x_max=0, y_max=0))

        left, upper, right, lower = x.min(), y.min(), x.max(), y.max()

        # Add the margin
        left = max(0, left - self.margin)
        upper = max(0, upper - self.margin)
        right = min(mask_np.shape[1], right + self.margin)
        lower = min(mask_np.shape[0], lower + self.margin)

        bounding_box = BoundingBoxField(x_min=left, y_min=upper, x_max=right, y_max=lower)
        return BoundingBoxOutput(bounding_box=bounding_box)
@@ -201,3 +214,48 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=masked_image)
return ImageOutput.build(image_dto)
WHITE = ColorField(r=255, g=255, b=255, a=255)
@invocation(
"get_image_mask_bounding_box",
title="Get Image Mask Bounding Box",
tags=["mask"],
category="mask",
version="1.0.0",
classification=Classification.Beta,
)
class GetMaskBoundingBoxInvocation(BaseInvocation):
"""Gets the bounding box of the given mask image."""
mask: ImageField = InputField(description="The mask to crop.")
margin: int = InputField(default=0, description="Margin to add to the bounding box.")
mask_color: ColorField = InputField(default=WHITE, description="Color of the mask in the image.")
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
mask = context.images.get_pil(self.mask.image_name, mode="RGBA")
mask_np = np.array(mask)
# Convert mask_color to RGBA tuple
mask_color_rgb = self.mask_color.tuple()
# Find the bounding box of the mask color
y, x = np.where(np.all(mask_np == mask_color_rgb, axis=-1))
if len(x) == 0 or len(y) == 0:
# No pixels found with the given color
return BoundingBoxOutput(bounding_box=BoundingBoxField(x_min=0, y_min=0, x_max=0, y_max=0))
left, upper, right, lower = x.min(), y.min(), x.max(), y.max()
# Add the margin
left = max(0, left - self.margin)
upper = max(0, upper - self.margin)
right = min(mask_np.shape[1], right + self.margin)
lower = min(mask_np.shape[0], lower + self.margin)
bounding_box = BoundingBoxField(x_min=left, y_min=upper, x_max=right, y_max=lower)
return BoundingBoxOutput(bounding_box=bounding_box)
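The color match reduces to a boolean mask via np.all over the channel axis; np.where then yields the matching row and column indices. A toy run of the same logic:
```
import numpy as np

mask_np = np.zeros((8, 8, 4), dtype=np.uint8)
mask_np[2:5, 3:7] = (255, 255, 255, 255)  # a white RGBA block
y, x = np.where(np.all(mask_np == (255, 255, 255, 255), axis=-1))
print(x.min(), y.min(), x.max(), y.max())  # 3 2 6 4
```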

View File

@@ -7,7 +7,6 @@ import torch
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -539,23 +538,3 @@ class BoundingBoxInvocation(BaseInvocation):
# endregion
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "internal"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")

View File

@@ -10,6 +10,10 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.model_manager.config import SubModelType
@@ -88,16 +92,8 @@ class Sd3ModelLoaderInvocation(BaseInvocation):
if self.clip_g_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
)
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
tokenizer_t5 = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model or self.model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model or self.model)
return Sd3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),

View File

@@ -218,6 +218,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
unet_config=unet_config,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)

View File

@@ -87,6 +87,7 @@ class InvokeAIAppConfig(BaseSettings):
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as it is used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
@@ -162,6 +163,7 @@ class InvokeAIAppConfig(BaseSettings):
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as it is used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
# Deprecated CACHE configs
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
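A sketch of the tradeoff the new setting exposes, constructing the settings object directly (the import path is an assumption; real deployments set these via the config file or environment):
```
from invokeai.app.services.config import InvokeAIAppConfig  # assumed import path

# Lower RAM pressure at the cost of slower model switching / LoRA patching.
config = InvokeAIAppConfig(enable_partial_loading=True, keep_ram_copy_of_weights=False)
```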

View File

@@ -84,6 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
execution_device_working_mem_gb=app_config.device_working_mem_gb,
enable_partial_loading=app_config.enable_partial_loading,
keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights,
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
execution_device=execution_device or TorchDevice.choose_torch_device(),

View File

@@ -108,8 +108,16 @@ class Batch(BaseModel):
return v
for batch_data_list in v:
for datum in batch_data_list:
if not datum.items:
continue
# Special handling for numbers - they can be mixed
# TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
if all(isinstance(item, (int, float)) for item in datum.items):
continue
# Get the type of the first item in the list
first_item_type = type(datum.items[0]) if datum.items else None
first_item_type = type(datum.items[0])
for item in datum.items:
if type(item) is not first_item_type:
raise BatchItemsTypeError("All items in a batch must have the same type")
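The relaxed rule in isolation: ints and floats may be mixed, but any other item types must be homogeneous. A self-contained restatement of the check:
```
def check_batch_items(items: list) -> None:
    if not items:
        return
    if all(isinstance(item, (int, float)) for item in items):
        return  # numbers may be mixed
    first_item_type = type(items[0])
    for item in items:
        if type(item) is not first_item_type:
            raise TypeError("All items in a batch must have the same type")

check_batch_items([1, 2.5, 3])  # ok: mixed numbers
check_batch_items(["a", "b"])   # ok: homogeneous strings
check_batch_items(["a", 1])     # raises TypeError
```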

View File

@@ -0,0 +1,26 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.backend.model_manager.config import BaseModelType, SubModelType
def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 encoder model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")
def preprocess_t5_tokenizer_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
"""A helper function to normalize a T5 tokenizer model identifier so that T5 models associated with FLUX
or SD3 models can be used interchangeably.
"""
if model_identifier.base == BaseModelType.Any:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
elif model_identifier.base == BaseModelType.StableDiffusion3:
return model_identifier.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
else:
raise ValueError(f"Unsupported model base: {model_identifier.base}")
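The dispatch the two helpers implement, restated as a lookup table (the dict name is illustrative): base=Any is treated as FLUX-style and maps to the *2 submodels, while SD3 maps to the *3 submodels.
```
from invokeai.backend.model_manager.config import BaseModelType, SubModelType

# base model -> (tokenizer submodel, text encoder submodel)
T5_SUBMODEL_MAP = {
    BaseModelType.Any: (SubModelType.Tokenizer2, SubModelType.TextEncoder2),
    BaseModelType.StableDiffusion3: (SubModelType.Tokenizer3, SubModelType.TextEncoder3),
}
```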

View File

@@ -1,13 +1,19 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
from invokeai.backend.util.devices import TorchDevice
class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
def __init__(
self,
encoder: PreTrainedModel,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
is_clip: bool,
max_length: int,
):
super().__init__()
self.max_length = max_length
self.is_clip = is_clip

View File

@@ -9,12 +9,17 @@ class CachedModelOnlyFullLoad:
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
def __init__(
self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False
):
"""Initialize a CachedModelOnlyFullLoad.
Args:
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
compute_device (torch.device): The compute device to move the model to.
total_bytes (int): The total size (in bytes) of all the weights in the model.
keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy
increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is
sufficient RAM).
"""
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
self._model = model
@@ -23,7 +28,7 @@ class CachedModelOnlyFullLoad:
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] | None = None
if isinstance(model, torch.nn.Module):
if isinstance(model, torch.nn.Module) and keep_ram_copy:
self._cpu_state_dict = model.state_dict()
self._total_bytes = total_bytes

View File

@@ -14,33 +14,38 @@ class CachedModelWithPartialLoad:
MPS memory, etc.
"""
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False):
self._model = model
self._compute_device = compute_device
# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
model_state_dict = model.state_dict()
# A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA
# patching. Set to `None` if keep_ram_copy is False.
self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None
# A dictionary of the size of each tensor in the state dict.
# HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
# consistency in case the application code has modified the model's size (e.g. by casting to a different
# precision). Of course, this means that we are making model cache load/unload decisions based on model size
# data that may not be fully accurate.
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()}
self._total_bytes = sum(self._state_dict_bytes.values())
self._cur_vram_bytes: int | None = None
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
model_state_dict
)
self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict)
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
"""Find all modules that support autocasting."""
return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]:
keys_in_modules_that_do_not_support_autocast: set[str] = set()
for key in self._cpu_state_dict.keys():
for key in state_dict.keys():
for module_name in self._modules_that_support_autocast.keys():
if key.startswith(module_name):
break
@@ -48,6 +53,47 @@ class CachedModelWithPartialLoad:
keys_in_modules_that_do_not_support_autocast.add(key)
return keys_in_modules_that_do_not_support_autocast
def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]:
"""A helper function that groups state dict keys by module prefix.
Example:
```
state_dict = {
"weight": ...,
"module.submodule.weight": ...,
"module.submodule.bias": ...,
"module.other_submodule.weight": ...,
"module.other_submodule.bias": ...,
}
output = group_state_dict_keys_by_module_prefix(state_dict)
# The output will be:
output = {
"": [
"weight",
],
"module.submodule": [
"module.submodule.weight",
"module.submodule.bias",
],
"module.other_submodule": [
"module.other_submodule.weight",
"module.other_submodule.bias",
],
}
```
"""
state_dict_keys_by_module_prefix: dict[str, list[str]] = {}
for key in state_dict.keys():
split = key.rsplit(".", 1)
# `split` will have length 1 if the root module has parameters.
module_name = split[0] if len(split) > 1 else ""
if module_name not in state_dict_keys_by_module_prefix:
state_dict_keys_by_module_prefix[module_name] = []
state_dict_keys_by_module_prefix[module_name].append(key)
return state_dict_keys_by_module_prefix
def _move_non_persistent_buffers_to_device(self, device: torch.device):
"""Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
so we need to move them manually.
@@ -98,6 +144,82 @@ class CachedModelWithPartialLoad:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())
def _load_state_dict_with_device_conversion(
self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
):
if self._cpu_state_dict is not None:
# Run the fast version.
self._load_state_dict_with_fast_device_conversion(
state_dict=state_dict,
keys_to_convert=keys_to_convert,
target_device=target_device,
cpu_state_dict=self._cpu_state_dict,
)
else:
# Run the low-virtual-memory version.
self._load_state_dict_with_jit_device_conversion(
state_dict=state_dict,
keys_to_convert=keys_to_convert,
target_device=target_device,
)
def _load_state_dict_with_jit_device_conversion(
self,
state_dict: dict[str, torch.Tensor],
keys_to_convert: set[str],
target_device: torch.device,
):
"""A custom state dict loading implementation with good peak memory properties.
This implementation has the important property that it copies parameters to the target device one module at a time
rather than applying all of the device conversions and then calling load_state_dict(). This is done to minimize the
peak virtual memory usage. Specifically, we want to avoid a case where we hold references to all of the CPU weights
and CUDA weights simultaneously, because Windows will reserve virtual memory for both.
"""
for module_name, module in self._model.named_modules():
module_keys = self._state_dict_keys_by_module_prefix.get(module_name, [])
# Calculate the length of the module name prefix.
prefix_len = len(module_name)
if prefix_len > 0:
prefix_len += 1
module_state_dict = {}
for key in module_keys:
if key in keys_to_convert:
# It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same
# parameter.
state_dict[key] = state_dict[key].to(target_device)
# Note that we keep parameters that have not been moved to a new device in case the module implements
# weird custom state dict loading logic that requires all parameters to be present.
module_state_dict[key[prefix_len:]] = state_dict[key]
if len(module_state_dict) > 0:
# We set strict=False, because if `module` has both parameters and child modules, then we are loading a
# state dict that only contains the parameters of `module` (not its children).
# We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf
# modules will recurse through all of the children, so is a bit wasteful.
incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True)
# Missing keys are ok, unexpected keys are not.
assert len(incompatible_keys.unexpected_keys) == 0
def _load_state_dict_with_fast_device_conversion(
self,
state_dict: dict[str, torch.Tensor],
keys_to_convert: set[str],
target_device: torch.device,
cpu_state_dict: dict[str, torch.Tensor],
):
"""Convert parameters to the target device and load them into the model. Leverages the `cpu_state_dict` to speed
up transfers of weights to the CPU.
"""
for key in keys_to_convert:
if target_device.type == "cpu":
state_dict[key] = cpu_state_dict[key]
else:
state_dict[key] = state_dict[key].to(target_device)
self._model.load_state_dict(state_dict, assign=True)
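Both paths rely on load_state_dict(assign=True), which swaps the provided tensors into the module rather than copying into the existing parameters. A toy check of that property (assumes a torch version that supports the assign kwarg, i.e. 2.1+):
```
import torch

linear = torch.nn.Linear(2, 2)
sd = {k: v.clone() for k, v in linear.state_dict().items()}
linear.load_state_dict(sd, assign=True)
# The module's weight now *is* the state dict tensor; no copy was made.
assert linear.weight.data_ptr() == sd["weight"].data_ptr()
```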
@torch.no_grad()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
@@ -112,26 +234,33 @@ class CachedModelWithPartialLoad:
cur_state_dict = self._model.state_dict()
# Identify the keys that will be loaded into VRAM.
keys_to_load: set[str] = set()
# First, process the keys that *must* be loaded into VRAM.
for key in self._keys_in_modules_that_do_not_support_autocast:
param = cur_state_dict[key]
if param.device.type == self._compute_device.type:
continue
keys_to_load.add(key)
param_size = self._state_dict_bytes[key]
cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size
if vram_bytes_loaded > vram_bytes_to_load:
logger = InvokeAILogger.get_logger()
logger.warning(
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
f"Loading {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
"requested. This is the minimum set of weights in VRAM required to run the model."
)
# Next, process the keys that can optionally be loaded into VRAM.
fully_loaded = True
for key, param in cur_state_dict.items():
# Skip the keys that have already been processed above.
if key in keys_to_load:
continue
if param.device.type == self._compute_device.type:
continue
@@ -142,14 +271,14 @@ class CachedModelWithPartialLoad:
fully_loaded = False
continue
cur_state_dict[key] = param.to(self._compute_device, copy=True)
keys_to_load.add(key)
vram_bytes_loaded += param_size
if vram_bytes_loaded > 0:
if len(keys_to_load) > 0:
# We load the entire state dict, not just the parameters that changed, in case there are modules that
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
self._model.load_state_dict(cur_state_dict, assign=True)
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_load, self._compute_device)
if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
@@ -180,6 +309,10 @@ class CachedModelWithPartialLoad:
offload_device = "cpu"
cur_state_dict = self._model.state_dict()
# Identify the keys that will be offloaded to CPU.
keys_to_offload: set[str] = set()
for key, param in cur_state_dict.items():
if vram_bytes_freed >= vram_bytes_to_free:
break
@@ -191,11 +324,11 @@ class CachedModelWithPartialLoad:
required_weights_in_vram += self._state_dict_bytes[key]
continue
cur_state_dict[key] = self._cpu_state_dict[key]
keys_to_offload.add(key)
vram_bytes_freed += self._state_dict_bytes[key]
if vram_bytes_freed > 0:
self._model.load_state_dict(cur_state_dict, assign=True)
if len(keys_to_offload) > 0:
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_offload, torch.device("cpu"))
if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed

View File

@@ -78,6 +78,7 @@ class ModelCache:
self,
execution_device_working_mem_gb: float,
enable_partial_loading: bool,
keep_ram_copy_of_weights: bool,
max_ram_cache_size_gb: float | None = None,
max_vram_cache_size_gb: float | None = None,
execution_device: torch.device | str = "cuda",
@@ -105,6 +106,7 @@ class ModelCache:
:param logger: InvokeAILogger to use (otherwise creates one)
"""
self._enable_partial_loading = enable_partial_loading
self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
self._execution_device_working_mem_gb = execution_device_working_mem_gb
self._execution_device: torch.device = torch.device(execution_device)
self._storage_device: torch.device = torch.device(storage_device)
@@ -121,6 +123,8 @@ class ModelCache:
self._cached_models: Dict[str, CacheRecord] = {}
self._cache_stack: List[str] = []
self._ram_cache_size_bytes = self._calc_ram_available_to_model_cache()
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
@@ -154,9 +158,13 @@ class ModelCache:
# Wrap model.
if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading:
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
wrapped_model = CachedModelWithPartialLoad(
model, self._execution_device, keep_ram_copy=self._keep_ram_copy_of_weights
)
else:
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
wrapped_model = CachedModelOnlyFullLoad(
model, self._execution_device, size, keep_ram_copy=self._keep_ram_copy_of_weights
)
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
self._cached_models[key] = cache_record
@@ -339,16 +347,17 @@ class ModelCache:
self._delete_cache_entry(cache_entry)
raise
def _get_total_vram_available_to_cache(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the total amount of VRAM available for storing models. I.e. the amount of VRAM available to the
process minus the amount of VRAM to keep for working memory.
def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the amount of additional VRAM available for the cache to use (takes into account the working
memory).
"""
# If self._max_vram_cache_size_gb is set, then it overrides the default logic.
if self._max_vram_cache_size_gb is not None:
return int(self._max_vram_cache_size_gb * GB)
vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
return vram_total_available_to_cache - self._get_vram_in_use()
working_mem_bytes_default = int(self._execution_device_working_mem_gb * GB)
working_mem_bytes = max(working_mem_bytes or 0, working_mem_bytes_default)
working_mem_bytes = max(working_mem_bytes or working_mem_bytes_default, working_mem_bytes_default)
if self._execution_device.type == "cuda":
# TODO(ryand): It is debatable whether we should use memory_reserved() or memory_allocated() here.
@@ -359,28 +368,19 @@ class ModelCache:
vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
vram_available_to_process = vram_free + vram_allocated
elif self._execution_device.type == "mps":
vram_allocated = torch.mps.driver_allocated_memory()
vram_reserved = torch.mps.driver_allocated_memory()
# TODO(ryand): Is it accurate that MPS shares memory with the CPU?
vram_free = psutil.virtual_memory().available
vram_available_to_process = vram_free + vram_allocated
vram_available_to_process = vram_free + vram_reserved
else:
raise ValueError(f"Unsupported execution device: {self._execution_device.type}")
return vram_available_to_process - working_mem_bytes
def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the amount of additional VRAM available for the model cache to use (takes into account the working
memory).
"""
return self._get_total_vram_available_to_cache(working_mem_bytes) - self._get_vram_in_use()
vram_total_available_to_cache = vram_available_to_process - working_mem_bytes
vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
return vram_cur_available_to_cache
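A worked example of the accounting above, with illustrative numbers:
```
GB = 2**30
vram_free, vram_allocated = 6 * GB, 4 * GB  # from torch.cuda.mem_get_info / memory_allocated
working_mem_bytes = 3 * GB                  # reserved as working memory
vram_in_use = 4 * GB                        # attributed to models already in the cache
available = (vram_free + vram_allocated) - working_mem_bytes - vram_in_use
print(available / GB)  # 3.0 GB of additional VRAM the cache may use
```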
def _get_vram_in_use(self) -> int:
"""Get the amount of VRAM currently in use by the cache."""
# NOTE(ryand): To be conservative, we are treating the amount of VRAM allocated by torch as entirely being used
# by the model cache. In reality, some of this allocated memory is being used as working memory. This is a
# reasonable conservative assumption, because this function is typically called before (not during)
# working-memory-intensive operations. This conservative definition also helps to handle models whose size
# increased after initial load (e.g. a model whose precision was upcast by application code).
if self._execution_device.type == "cuda":
return torch.cuda.memory_allocated()
elif self._execution_device.type == "mps":
@@ -390,83 +390,89 @@ class ModelCache:
# Alternative definition of VRAM in use:
# return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
def _get_ram_available(self) -> int:
"""Get the amount of RAM available for the cache to use, while keeping memory pressure under control."""
def _calc_ram_available_to_model_cache(self) -> int:
"""Calculate the amount of RAM available for the cache to use."""
# If self._max_ram_cache_size_gb is set, then it overrides the default logic.
if self._max_ram_cache_size_gb is not None:
ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
return ram_total_available_to_cache - self._get_ram_in_use()
self._logger.info(f"Using user-defined RAM cache size: {self._max_ram_cache_size_gb} GB.")
return int(self._max_ram_cache_size_gb * GB)
# We have 3 strategies for calculating the amount of RAM available to the cache. We calculate all 3 options and
# then use a heuristic to decide which one to use.
# - Strategy 1: Match RAM cache size to VRAM cache size
# - Strategy 2: Aim to keep at least 10% of RAM free
# - Strategy 3: Use a minimum RAM cache size of 4GB
# Heuristics for dynamically calculating the RAM cache size, **in order of increasing priority**:
# 1. As an initial default, use 50% of the total RAM for InvokeAI.
# - Assume a 2GB baseline for InvokeAI's non-model RAM usage, and use the rest of the RAM for the model cache.
# 2. On a system with a lot of RAM (e.g. 64GB+), users probably don't want InvokeAI to eat up too much RAM.
# There are diminishing returns to storing more and more models. So, we apply an upper bound.
# - On systems without a CUDA device, the upper bound is 32GB.
# - On systems with a CUDA device, the upper bound is 2x the amount of VRAM.
# 3. On systems with a CUDA device, the minimum should be the VRAM size (less the working memory).
# - Setting lower than this would mean that we sometimes kick models out of the cache when there is room for
# all models in VRAM.
# - Consider an extreme case of a system with 8GB RAM / 24GB VRAM. I haven't tested this, but I think
# you'd still want the RAM cache size to be ~24GB (less the working memory). (Though you'd probably want to
# set `keep_ram_copy_of_weights: false` in this case.)
# 4. Absolute minimum of 4GB.
# ---------------------
# Calculate Strategy 1
# ---------------------
# Under Strategy 1, the RAM cache size is equal to the total VRAM available to the cache. The RAM cache size
# should **roughly** match the VRAM cache size for the following reasons:
# - Setting it much larger than the VRAM cache size means that we would accumulate mmap'ed model files for
# models that are 0% loaded onto the GPU. Accumulating a large amount of virtual memory causes issues -
# particularly on Windows. Instead, we should drop these extra models from the cache and rely on the OS's
# disk caching behavior to make reloading them fast (if there is enough RAM for disk caching to be possible).
# - Setting it much smaller than the VRAM cache size would increase the likelihood that we drop models from the
# cache even if they are partially loaded onto the GPU.
#
# TODO(ryand): In the future, we should re-think this strategy. Setting the RAM cache size like this doesn't
# really make sense, and is done primarily for consistency with legacy behavior. We should be relying on the
# OS's caching behavior more and make decisions about whether to drop models from the cache based primarily on
# how much of the model can be kept in VRAM.
cache_ram_used = self._get_ram_in_use()
if self._execution_device.type == "cpu":
# Strategy 1 is not applicable for CPU.
ram_available_based_on_default_ram_cache_size = 0
else:
default_ram_cache_size_bytes = self._get_total_vram_available_to_cache(None)
ram_available_based_on_default_ram_cache_size = default_ram_cache_size_bytes - cache_ram_used
# NOTE(ryand): We explored dynamically adjusting the RAM cache size based on memory pressure (using psutil), but
# decided against it for now, for the following reasons:
# - It was surprisingly difficult to get memory metrics with consistent definitions across OSes. (If you go
# down this path again, don't underestimate the amount of complexity here and be sure to test rigorously on all
# OSes.)
# - Making the RAM cache size dynamic opens the door for performance regressions that are hard to diagnose and
# hard for users to understand. It is better for users to see that their RAM is maxed out, and then override
# the default value if desired.
# ---------------------
# Calculate Strategy 2
# ---------------------
# If RAM memory pressure is high, then we want to be more conservative with the RAM cache size.
virtual_memory = psutil.virtual_memory()
ram_total = virtual_memory.total
ram_available = virtual_memory.available
ram_used = ram_total - ram_available
# We aim to keep at least 10% of RAM free.
ram_available_based_on_memory_usage = int(ram_total * 0.9) - ram_used
# Lookup the total VRAM size for the CUDA execution device.
total_cuda_vram_bytes: int | None = None
if self._execution_device.type == "cuda":
_, total_cuda_vram_bytes = torch.cuda.mem_get_info(self._execution_device)
# ---------------------
# Calculate Strategy 3
# ---------------------
# If the RAM cache is very small, then there's an increased likelihood that we will run into this issue:
# https://github.com/invoke-ai/InvokeAI/issues/7513
# To keep things running smoothly, there's a minimum RAM cache size that we always allow (even if this means
# using swap).
min_ram_cache_size_bytes = 4 * GB
ram_available_based_on_min_cache_size = min_ram_cache_size_bytes - cache_ram_used
# Apply heuristic 1.
# ------------------
heuristics_applied = [1]
total_system_ram_bytes = psutil.virtual_memory().total
# Assumed baseline RAM used by InvokeAI for non-model stuff.
baseline_ram_used_by_invokeai = 2 * GB
ram_available_to_model_cache = int(total_system_ram_bytes * 0.5 - baseline_ram_used_by_invokeai)
# ----------------------------
# Decide which strategy to use
# ----------------------------
# First, take the minimum of strategies 1 and 2.
ram_available = min(ram_available_based_on_default_ram_cache_size, ram_available_based_on_memory_usage)
# Then, apply strategy 3 as the lower bound.
ram_available = max(ram_available, ram_available_based_on_min_cache_size)
self._logger.debug(
f"Calculated RAM available: {ram_available/MB:.2f} MB. Strategies considered (1,2,3): "
f"{ram_available_based_on_default_ram_cache_size/MB:.2f}, "
f"{ram_available_based_on_memory_usage/MB:.2f}, "
f"{ram_available_based_on_min_cache_size/MB:.2f}"
# Apply heuristic 2.
# ------------------
max_ram_cache_size_bytes = 32 * GB
if total_cuda_vram_bytes is not None:
max_ram_cache_size_bytes = 2 * total_cuda_vram_bytes
if ram_available_to_model_cache > max_ram_cache_size_bytes:
heuristics_applied.append(2)
ram_available_to_model_cache = max_ram_cache_size_bytes
# Apply heuristic 3.
# ------------------
if total_cuda_vram_bytes is not None:
if self._max_vram_cache_size_gb is not None:
min_ram_cache_size_bytes = int(self._max_vram_cache_size_gb * GB)
else:
min_ram_cache_size_bytes = total_cuda_vram_bytes - int(self._execution_device_working_mem_gb * GB)
if ram_available_to_model_cache < min_ram_cache_size_bytes:
heuristics_applied.append(3)
ram_available_to_model_cache = min_ram_cache_size_bytes
# Apply heuristic 4.
# ------------------
if ram_available_to_model_cache < 4 * GB:
heuristics_applied.append(4)
ram_available_to_model_cache = 4 * GB
self._logger.info(
f"Calculated model RAM cache size: {ram_available_to_model_cache / MB:.2f} MB. Heuristics applied: {heuristics_applied}."
)
return ram_available
return ram_available_to_model_cache
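The four heuristics, restated as a pure function (GB units; a sketch that omits the max_vram_cache_size_gb override used in heuristic 3):
```
def calc_ram_cache_gb(total_ram_gb: float, cuda_vram_gb: float | None = None,
                      working_mem_gb: float = 3.0) -> float:
    size = total_ram_gb * 0.5 - 2.0  # 1: half of system RAM, minus a ~2GB app baseline
    upper = 2 * cuda_vram_gb if cuda_vram_gb is not None else 32.0  # 2: upper bound
    size = min(size, upper)
    if cuda_vram_gb is not None:  # 3: floor at VRAM size less the working memory
        size = max(size, cuda_vram_gb - working_mem_gb)
    return max(size, 4.0)  # 4: absolute minimum of 4GB

print(calc_ram_cache_gb(32.0, cuda_vram_gb=8.0))  # 14.0
```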
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
def _get_ram_available(self) -> int:
"""Get the amount of RAM available for the cache to use."""
return self._ram_cache_size_bytes - self._get_ram_in_use()
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()

View File

@@ -80,19 +80,19 @@ class FluxVAELoader(ModelLoader):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
# VAE is broken in float16, which mps defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.float32
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
# VAE is broken in float16, which mps defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.float32
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)
return model
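The reindented lines above appear to move the state dict load out of the init_empty_weights() context; only the module construction needs the meta device. The underlying pattern, sketched on a toy module (assumes accelerate is installed and torch supports assign=True):
```
import accelerate
import torch

with accelerate.init_empty_weights():
    model = torch.nn.Linear(4, 4)  # parameters land on the meta device, using no RAM

sd = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
model.load_state_dict(sd, assign=True)  # materialize real tensors in place of meta ones
```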
@@ -183,7 +183,9 @@ class T5EncoderCheckpointModel(ModelLoader):
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
return T5EncoderModel.from_pretrained(
Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
)
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -217,17 +219,18 @@ class FluxCheckpointModel(ModelLoader):
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
self._ram_cache.make_room(new_sd_size)
for k in sd.keys():
# We need to cast to bfloat16 due to it being the only currently supported dtype for inference
sd[k] = sd[k].to(torch.bfloat16)
model.load_state_dict(sd, assign=True)
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
self._ram_cache.make_room(new_sd_size)
for k in sd.keys():
# We need to cast to bfloat16 due to it being the only currently supported dtype for inference
sd[k] = sd[k].to(torch.bfloat16)
model.load_state_dict(sd, assign=True)
return model
@@ -258,11 +261,11 @@ class FluxGGUFCheckpointModel(ModelLoader):
assert isinstance(config, MainGGUFCheckpointConfig)
model_path = Path(config.path)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
# HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
# We override the shape here to fix the issue.

View File

@@ -76,6 +76,7 @@
"konva": "^9.3.15",
"lodash-es": "^4.17.21",
"lru-cache": "^11.0.1",
"mtwist": "^1.0.2",
"nanoid": "^5.0.7",
"nanostores": "^0.11.3",
"new-github-issue-url": "^1.0.0",

View File

@@ -77,6 +77,9 @@ dependencies:
lru-cache:
specifier: ^11.0.1
version: 11.0.1
mtwist:
specifier: ^1.0.2
version: 1.0.2
nanoid:
specifier: ^5.0.7
version: 5.0.7
@@ -7016,6 +7019,10 @@ packages:
/ms@2.1.3:
resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==}
/mtwist@1.0.2:
resolution: {integrity: sha512-eRsSga5jkLg7nNERPOV8vDNxgSwuEcj5upQfJcT0gXfJwXo3pMc7xOga0fu8rXHyrxzl7GFVWWDuaPQgpKDvgw==}
dev: false
/muggle-string@0.3.1:
resolution: {integrity: sha512-ckmWDJjphvd/FvZawgygcUeQCxzvohjFO5RxTjj4eq8kw359gFF3E1brjfI+viLMxss5JrHTDRHZvu2/tuy0Qg==}
dev: true

View File

@@ -177,7 +177,17 @@
"none": "None",
"new": "New",
"generating": "Generating",
"warnings": "Warnings"
"warnings": "Warnings",
"start": "Start",
"count": "Count",
"step": "Step",
"end": "End",
"min": "Min",
"max": "Max",
"values": "Values",
"resetToDefaults": "Reset to Defaults",
"seed": "Seed",
"combinatorial": "Combinatorial"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -850,6 +860,19 @@
"defaultVAE": "Default VAE"
},
"nodes": {
"arithmeticSequence": "Arithmetic Sequence",
"linearDistribution": "Linear Distribution",
"uniformRandomDistribution": "Uniform Random Distribution",
"parseString": "Parse String",
"splitOn": "Split On",
"noBatchGroup": "no group",
"generatorNRandomValues_one": "{{count}} random value",
"generatorNRandomValues_other": "{{count}} random values",
"generatorNoValues": "empty",
"generatorLoading": "loading",
"generatorLoadFromFile": "Load from File",
"dynamicPromptsRandom": "Dynamic Prompts (Random)",
"dynamicPromptsCombinatorial": "Dynamic Prompts (Combinatorial)",
"addNode": "Add Node",
"addNodeToolTip": "Add Node (Shift+A, Space)",
"addLinearView": "Add to Linear View",
@@ -989,7 +1012,11 @@
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default",
"saveToGallery": "Save To Gallery"
"saveToGallery": "Save To Gallery",
"addItem": "Add Item",
"generateValues": "Generate Values",
"floatRangeGenerator": "Float Range Generator",
"integerRangeGenerator": "Integer Range Generator"
},
"parameters": {
"aspect": "Aspect",
@@ -1024,11 +1051,22 @@
"addingImagesTo": "Adding images to",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: missing input",
"missingInputForField": "missing input",
"missingNodeTemplate": "Missing node template",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} empty collection",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: too few items, minimum {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: too many items, maximum {{maxItems}}",
"emptyBatches": "empty batches",
"batchNodeNotConnected": "Batch node not connected: {{label}}",
"batchNodeEmptyCollection": "Some batch nodes have empty collections",
"invalidBatchConfigurationCannotCalculate": "Invalid batch configuration; cannot calculate",
"collectionTooFewItems": "too few items, minimum {{minItems}}",
"collectionTooManyItems": "too many items, maximum {{maxItems}}",
"collectionStringTooLong": "too long, max {{maxLength}}",
"collectionStringTooShort": "too short, min {{minLength}}",
"collectionNumberGTMax": "{{value}} > {{maximum}} (inc max)",
"collectionNumberLTMin": "{{value}} < {{minimum}} (inc min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)",
"collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}",
"batchNodeCollectionSizeMismatch": "Collection size mismatch on Batch {{batchGroupId}}",
"noModelSelected": "No model selected",
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
@@ -1100,7 +1138,8 @@
"perPromptLabel": "Seed per Image",
"perPromptDesc": "Use a different seed for each image"
},
"loading": "Generating Dynamic Prompts..."
"loading": "Generating Dynamic Prompts...",
"promptsToGenerate": "Prompts to Generate"
},
"sdxl": {
"cfgScale": "CFG Scale",
@@ -1932,6 +1971,24 @@
"description": "Generates an edge map from the selected layer using the PiDiNet edge detection model.",
"scribble": "Scribble",
"quantize_edges": "Quantize Edges"
},
"img_blur": {
"label": "Blur Image",
"description": "Blurs the selected layer.",
"blur_type": "Blur Type",
"blur_radius": "Radius",
"gaussian_type": "Gaussian",
"box_type": "Box"
},
"img_noise": {
"label": "Noise Image",
"description": "Adds noise to the selected layer.",
"noise_type": "Noise Type",
"noise_amount": "Amount",
"gaussian_type": "Gaussian",
"salt_and_pepper_type": "Salt and Pepper",
"noise_color": "Colored Noise",
"size": "Noise Size"
}
},
"transform": {
@@ -2139,7 +2196,13 @@
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": ["Low-VRAM mode", "Dynamic memory management", "Faster model loading times", "Fewer memory errors"],
"items": [
"Low-VRAM mode",
"Dynamic memory management",
"Faster model loading times",
"Fewer memory errors",
"Expanded workflow batch capabilities"
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
"watchUiUpdatesOverview": "Watch UI Updates Overview"

View File

@@ -901,9 +901,7 @@
}
},
"newUserExperience": {
"downloadStarterModels": "Descargar modelos de inicio",
"toGetStarted": "Para empezar, introduzca un mensaje en el cuadro y haga clic en <StrongComponent>Invocar</StrongComponent> para generar su primera imagen. Seleccione una plantilla para mejorar los resultados. Puede elegir guardar sus imágenes directamente en <StrongComponent>Galería</StrongComponent> o editarlas en <StrongComponent>Lienzo</StrongComponent>.",
"importModels": "Importar modelos",
"noModelsInstalled": "Parece que no tienes ningún modelo instalado",
"gettingStartedSeries": "¿Desea más orientación? Consulte nuestra <LinkComponent>Serie de introducción</LinkComponent> para obtener consejos sobre cómo aprovechar todo el potencial de Invoke Studio.",
"toGetStartedLocal": "Para empezar, asegúrate de descargar o importar los modelos necesarios para ejecutar Invoke. A continuación, introduzca un mensaje en el cuadro y haga clic en <StrongComponent>Invocar</StrongComponent> para generar su primera imagen. Seleccione una plantilla para mejorar los resultados. Puede elegir guardar sus imágenes directamente en <StrongComponent>Galería</StrongComponent> o editarlas en el <StrongComponent>Lienzo</StrongComponent>."

View File

@@ -352,7 +352,6 @@
"noT5EncoderModelSelected": "Aucun modèle T5 Encoder sélectionné pour la génération FLUX",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), la largeur de la bounding box mise à l'échelle est {{width}}",
"canvasIsCompositing": "La toile est en train de composer",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} collection vide",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}} : trop peu d'éléments, minimum {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}} : trop d'éléments, maximum {{maxItems}}",
"canvasIsSelectingObject": "La toile est occupée (sélection d'objet)"
@@ -2171,8 +2170,6 @@
"toGetStarted": "Pour commencer, saisissez un prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement dans la <StrongComponent>Galerie</StrongComponent> ou de les modifier sur la <StrongComponent>Toile</StrongComponent>.",
"gettingStartedSeries": "Vous souhaitez plus de conseils? Consultez notre <LinkComponent>Série de démarrage</LinkComponent> pour des astuces sur l'exploitation du plein potentiel de l'Invoke Studio.",
"noModelsInstalled": "Il semble qu'aucun modèle ne soit installé",
"downloadStarterModels": "Télécharger les modèles de démarrage",
"importModels": "Importer des Modèles",
"toGetStartedLocal": "Pour commencer, assurez-vous de télécharger ou d'importer des modèles nécessaires pour exécuter Invoke. Ensuite, saisissez le prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement sur <StrongComponent>Galerie</StrongComponent> ou les modifier sur la <StrongComponent>Toile</StrongComponent>."
},
"upsell": {

View File

@@ -97,7 +97,14 @@
"ok": "Ok",
"generating": "Generazione",
"loadingModel": "Caricamento del modello",
"warnings": "Avvisi"
"warnings": "Avvisi",
"step": "Passo",
"values": "Valori",
"start": "Inizio",
"end": "Fine",
"resetToDefaults": "Ripristina le impostazioni predefinite",
"seed": "Seme",
"combinatorial": "Combinatorio"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -668,7 +675,7 @@
"addingImagesTo": "Aggiungi immagini a",
"systemDisconnected": "Sistema disconnesso",
"missingNodeTemplate": "Modello di nodo mancante",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: ingresso mancante",
"missingInputForField": "ingresso mancante",
"missingFieldTemplate": "Modello di campo mancante",
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), altezza riquadro è {{height}}",
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), larghezza riquadro è {{width}}",
@@ -681,11 +688,22 @@
"canvasIsRasterizing": "La tela è occupata (sta rasterizzando)",
"canvasIsCompositing": "La tela è occupata (in composizione)",
"canvasIsFiltering": "La tela è occupata (sta filtrando)",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi elementi, massimo {{maxItems}}",
"collectionTooManyItems": "troppi elementi, massimo {{maxItems}}",
"canvasIsSelectingObject": "La tela è occupata (selezione dell'oggetto)",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: troppi pochi elementi, minimo {{minItems}}",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} raccolta vuota",
"fluxModelMultipleControlLoRAs": "È possibile utilizzare solo 1 Controllo LoRA alla volta"
"collectionTooFewItems": "troppi pochi elementi, minimo {{minItems}}",
"fluxModelMultipleControlLoRAs": "È possibile utilizzare solo 1 Controllo LoRA alla volta",
"collectionNumberGTMax": "{{value}} > {{maximum}} (incr max)",
"collectionStringTooLong": "troppo lungo, massimo {{maxLength}}",
"batchNodeNotConnected": "Nodo Lotto non connesso: {{label}}",
"batchNodeEmptyCollection": "Alcuni nodi lotto hanno raccolte vuote",
"emptyBatches": "lotti vuoti",
"batchNodeCollectionSizeMismatch": "Le dimensioni della raccolta nel Lotto {{batchGroupId}} non corrispondono",
"invalidBatchConfigurationCannotCalculate": "Configurazione lotto non valida; impossibile calcolare",
"collectionStringTooShort": "troppo corto, minimo {{minLength}}",
"collectionNumberNotMultipleOf": "{{value}} non è multiplo di {{multipleOf}}",
"collectionNumberLTMin": "{{value}} < {{minimum}} (incr min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (excl max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)"
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@@ -813,7 +831,8 @@
"imagesWillBeAddedTo": "Le immagini caricate verranno aggiunte alle risorse della bacheca {{boardName}}.",
"uploadFailedInvalidUploadDesc_withCount_one": "Devi caricare al massimo 1 immagine PNG o JPEG.",
"uploadFailedInvalidUploadDesc_withCount_many": "Devi caricare al massimo {{count}} immagini PNG o JPEG.",
"uploadFailedInvalidUploadDesc_withCount_other": "Devi caricare al massimo {{count}} immagini PNG o JPEG."
"uploadFailedInvalidUploadDesc_withCount_other": "Devi caricare al massimo {{count}} immagini PNG o JPEG.",
"outOfMemoryErrorDescLocal": "Segui la nostra <LinkComponent>guida per bassa VRAM</LinkComponent> per ridurre gli OOM."
},
"accessibility": {
"invokeProgressBar": "Barra di avanzamento generazione",
@@ -972,7 +991,25 @@
"noWorkflows": "Nessun flusso di lavoro",
"workflowHelpText": "Hai bisogno di aiuto? Consulta la nostra guida <LinkComponent>Introduzione ai flussi di lavoro</LinkComponent>.",
"specialDesc": "Questa invocazione comporta una gestione speciale nell'applicazione. Ad esempio, i nodi Lotto vengono utilizzati per mettere in coda più grafici da un singolo flusso di lavoro.",
"internalDesc": "Questa invocazione è utilizzata internamente da Invoke. Potrebbe subire modifiche significative durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento."
"internalDesc": "Questa invocazione è utilizzata internamente da Invoke. Potrebbe subire modifiche significative durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento.",
"addItem": "Aggiungi elemento",
"generateValues": "Genera valori",
"generatorNoValues": "vuoto",
"linearDistribution": "Distribuzione lineare",
"parseString": "Analizza stringa",
"splitOn": "Diviso su",
"noBatchGroup": "nessun gruppo",
"generatorLoading": "caricamento",
"generatorLoadFromFile": "Carica da file",
"dynamicPromptsRandom": "Prompt dinamici (casuali)",
"dynamicPromptsCombinatorial": "Prompt dinamici (combinatori)",
"floatRangeGenerator": "Generatore di intervalli di numeri in virgola mobile",
"integerRangeGenerator": "Generatore di intervalli di numeri interi",
"uniformRandomDistribution": "Distribuzione casuale uniforme",
"generatorNRandomValues_one": "{{count}} valore casuale",
"generatorNRandomValues_many": "{{count}} valori casuali",
"generatorNRandomValues_other": "{{count}} valori casuali",
"arithmeticSequence": "Sequenza aritmetica"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1138,7 +1175,8 @@
"dynamicPrompts": "Prompt dinamici",
"promptsPreview": "Anteprima dei prompt",
"showDynamicPrompts": "Mostra prompt dinamici",
"loading": "Generazione prompt dinamici..."
"loading": "Generazione prompt dinamici...",
"promptsToGenerate": "Prompt da generare"
},
"popovers": {
"paramScheduler": {
@@ -1907,7 +1945,24 @@
},
"forMoreControl": "Per un maggiore controllo, fare clic su Avanzate qui sotto.",
"advanced": "Avanzate",
"processingLayerWith": "Elaborazione del livello con il filtro {{type}}."
"processingLayerWith": "Elaborazione del livello con il filtro {{type}}.",
"img_blur": {
"label": "Sfoca immagine",
"description": "Sfoca il livello selezionato.",
"blur_type": "Tipo di sfocatura",
"blur_radius": "Raggio",
"gaussian_type": "Gaussiana"
},
"img_noise": {
"size": "Dimensione del rumore",
"salt_and_pepper_type": "Sale e pepe",
"gaussian_type": "Gaussiano",
"noise_color": "Rumore colorato",
"description": "Aggiunge rumore al livello selezionato.",
"noise_type": "Tipo di rumore",
"label": "Aggiungi rumore",
"noise_amount": "Quantità"
}
},
"controlLayers_withCount_hidden": "Livelli di controllo ({{count}} nascosti)",
"regionalGuidance_withCount_hidden": "Guida regionale ({{count}} nascosti)",
@@ -2166,10 +2221,9 @@
"newUserExperience": {
"gettingStartedSeries": "Desideri maggiori informazioni? Consulta la nostra <LinkComponent>Getting Started Series</LinkComponent> per suggerimenti su come sfruttare appieno il potenziale di Invoke Studio.",
"toGetStarted": "Per iniziare, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>.",
"importModels": "Importa modelli",
"downloadStarterModels": "Scarica i modelli per iniziare",
"noModelsInstalled": "Sembra che tu non abbia installato alcun modello",
"toGetStartedLocal": "Per iniziare, assicurati di scaricare o importare i modelli necessari per eseguire Invoke. Quindi, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>."
"noModelsInstalled": "Sembra che non hai installato alcun modello! Puoi <DownloadStarterModelsButton>scaricare un pacchetto di modelli di avvio</DownloadStarterModelsButton> o <ImportModelsButton>importare modelli</ImportModelsButton>.",
"toGetStartedLocal": "Per iniziare, assicurati di scaricare o importare i modelli necessari per eseguire Invoke. Quindi, inserisci un prompt nella casella e fai clic su <StrongComponent>Invoke</StrongComponent> per generare la tua prima immagine. Seleziona un modello di prompt per migliorare i risultati. Puoi scegliere di salvare le tue immagini direttamente nella <StrongComponent>Galleria</StrongComponent> o modificarle nella <StrongComponent>Tela</StrongComponent>.",
"lowVRAMMode": "Per prestazioni ottimali, segui la nostra <LinkComponent>guida per bassa VRAM</LinkComponent>."
},
"whatsNew": {
"whatsNewInInvoke": "Novità in Invoke",
@@ -2177,7 +2231,10 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"<StrongComponent>Livelli di controllo Flux</StrongComponent>: nuovi modelli di controllo per il rilevamento dei bordi e la mappatura della profondità sono ora supportati per i modelli di Flux dev."
"Modalità Bassa-VRAM",
"Gestione dinamica della memoria",
"Tempi di caricamento del modello più rapidi",
"Meno errori di memoria"
]
},
"system": {

View File

@@ -220,7 +220,15 @@
"tab": "Tab",
"loadingModel": "Đang Tải Model",
"generating": "Đang Tạo Sinh",
"warnings": "Cảnh Báo"
"warnings": "Cảnh Báo",
"count": "Đếm",
"step": "Bước",
"values": "Giá Trị",
"start": "Bắt Đầu",
"end": "Kết Thúc",
"min": "Tối Thiểu",
"max": "Tối Đa",
"resetToDefaults": "Đặt Lại Về Mặc Định"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -965,7 +973,19 @@
"outputFieldTypeParseError": "Không thể phân tích loại dữ liệu đầu ra của {{node}}.{{field}} ({{message}})",
"modelAccessError": "Không thể tìm thấy model {{key}}, chuyển về mặc định",
"internalDesc": "Trình kích hoạt này được dùng bên trong bởi Invoke. Nó có thể phá hỏng thay đổi trong khi cập nhật ứng dụng và có thể bị xoá bất cứ lúc nào.",
"specialDesc": "Trình kích hoạt này có một số xử lý đặc biệt trong ứng dụng. Ví dụ, Node Hàng Loạt được dùng để xếp vào nhiều đồ thị từ một workflow."
"specialDesc": "Trình kích hoạt này có một số xử lý đặc biệt trong ứng dụng. Ví dụ, Node Hàng Loạt được dùng để xếp vào nhiều đồ thị từ một workflow.",
"addItem": "Thêm Mục",
"generateValues": "Cho Ra Giá Trị",
"floatRangeGenerator": "Phạm Vị Tạo Ra Số Thực",
"integerRangeGenerator": "Phạm Vị Tạo Ra Số Nguyên",
"linearDistribution": "Phân Bố Tuyến Tính",
"uniformRandomDistribution": "Phân Bố Ngẫu Nhiên Đồng Nhất",
"parseString": "Phân Tích Chuỗi",
"noBatchGroup": "không có nhóm",
"generatorNoValues": "trống",
"splitOn": "Tách Ở",
"arithmeticSequence": "Cấp Số Cộng",
"generatorNRandomValues_other": "{{count}} giá trị ngẫu nhiên"
},
"popovers": {
"paramCFGRescaleMultiplier": {
@@ -1433,13 +1453,24 @@
"missingNodeTemplate": "Thiếu mẫu trình bày node",
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), chiều dài hộp giới hạn là {{height}}",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), tỉ lệ chiều rộng hộp giới hạn là {{width}}",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: thiếu đầu vào",
"missingInputForField": "thiếu đầu vào",
"missingFieldTemplate": "Thiếu vùng mẫu trình bày",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} tài nguyên trống",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: quá ít mục, tối thiểu {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: quá nhiều mục, tối đa {{maxItems}}",
"collectionTooFewItems": "quá ít mục, tối thiểu là {{minItems}}",
"collectionTooManyItems": "quá nhiều mục, tối đa là {{maxItems}}",
"canvasIsSelectingObject": "Canvas đang bận (đang chọn đồ vật)",
"fluxModelMultipleControlLoRAs": "Chỉ có thể dùng 1 LoRA Điều Khiển Được"
"fluxModelMultipleControlLoRAs": "Chỉ có thể dùng 1 LoRA Điều Khiển Được",
"collectionStringTooLong": "quá dài, tối đa là {{maxLength}}",
"collectionStringTooShort": "quá ngắn, tối thiểu là {{minLength}}",
"collectionNumberGTMax": "{{value}} > {{maximum}} (giá trị tối đa)",
"collectionNumberLTMin": "{{value}} < {{minimum}} (giá trị tối thiểu)",
"collectionNumberNotMultipleOf": "{{value}} không phải bội của {{multipleOf}}",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (giá trị chọn lọc tối thiểu)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (giá trị chọn lọc tối đa)",
"batchNodeCollectionSizeMismatch": "Kích cỡ tài nguyên không phù hợp với Lô {{batchGroupId}}",
"emptyBatches": "lô trống",
"batchNodeNotConnected": "Node Hàng Loạt chưa được kết nối: {{label}}",
"batchNodeEmptyCollection": "Một vài node hàng loạt có tài nguyên rỗng",
"invalidBatchConfigurationCannotCalculate": "Thiết lập lô không hợp lệ; không thể tính toán"
},
"cfgScale": "Thang CFG",
"useSeed": "Dùng Hạt Giống",
@@ -1458,8 +1489,8 @@
"recallMetadata": "Gợi Lại Metadata",
"clipSkip": "CLIP Skip",
"general": "Cài Đặt Chung",
"boxBlur": "Box Blur",
"gaussianBlur": "Gaussian Blur",
"boxBlur": "Làm Mờ Dạng Box",
"gaussianBlur": "Làm Mờ Dạng Gaussian",
"staged": "Staged (Tăng khử nhiễu có hệ thống)",
"scaledHeight": "Tỉ Lệ Dài",
"cancel": {
@@ -1859,7 +1890,25 @@
},
"advanced": "Nâng Cao",
"processingLayerWith": "Đang xử lý layer với bộ lọc {{type}}.",
"forMoreControl": "Để kiểm soát tốt hơn, bấm vào mục Nâng Cao bên dưới."
"forMoreControl": "Để kiểm soát tốt hơn, bấm vào mục Nâng Cao bên dưới.",
"img_blur": {
"description": "Làm mờ layer được chọn.",
"blur_type": "Dạng Làm Mờ",
"blur_radius": "Radius",
"gaussian_type": "Gaussian",
"label": "Làm Mờ Ảnh",
"box_type": "Box"
},
"img_noise": {
"salt_and_pepper_type": "Salt and Pepper",
"noise_amount": "Lượng Nhiễu",
"label": "Độ Nhiễu Ảnh",
"description": "Tăng độ nhiễu vào layer được chọn.",
"noise_type": "Dạng Nhiễu",
"gaussian_type": "Gaussian",
"noise_color": "Màu Nhiễu",
"size": "Cỡ Nhiễu"
}
},
"transform": {
"fitModeCover": "Che Phủ",
@@ -2067,7 +2116,8 @@
"problemCopyingImage": "Không Thể Sao Chép Ảnh",
"problemDownloadingImage": "Không Thể Tải Xuống Ảnh",
"problemCopyingLayer": "Không Thể Sao Chép Layer",
"problemSavingLayer": "Không Thể Lưu Layer"
"problemSavingLayer": "Không Thể Lưu Layer",
"outOfMemoryErrorDescLocal": "Làm theo <LinkComponent>hướng dẫn VRAM Thấp</LinkComponent> của chúng tôi để hạn chế OOM (Tràn bộ nhớ)."
},
"ui": {
"tabs": {
@@ -2153,9 +2203,8 @@
"toGetStartedLocal": "Để bắt đầu, hãy chắc chắn đã tải xuống hoặc thêm vào model cần để chạy Invoke. Sau đó, nhập lệnh vào hộp và nhấp chuột vào <StrongComponent>Kích Hoạt</StrongComponent> để tạo ra bức ảnh đầu tiên. Chọn một mẫu trình bày cho lệnh để cải thiện kết quả. Bạn có thể chọn để lưu ảnh trực tiếp vào <StrongComponent>Thư Viện</StrongComponent> hoặc chỉnh sửa chúng ở <StrongComponent>Canvas</StrongComponent>.",
"gettingStartedSeries": "Cần thêm hướng dẫn? Xem thử <LinkComponent>Bắt Đầu Làm Quen</LinkComponent> để biết thêm mẹo khai thác toàn bộ tiềm năng của Invoke Studio.",
"toGetStarted": "Để bắt đầu, hãy nhập lệnh vào hộp và nhấp chuột vào <StrongComponent>Kích Hoạt</StrongComponent> để tạo ra bức ảnh đầu tiên. Chọn một mẫu trình bày cho lệnh để cải thiện kết quả. Bạn có thể chọn để lưu ảnh trực tiếp vào <StrongComponent>Thư Viện</StrongComponent> hoặc chỉnh sửa chúng ở <StrongComponent>Canvas</StrongComponent>.",
"downloadStarterModels": "Tải Xuống Model Khởi Đầu",
"importModels": "Nhập Vào Model",
"noModelsInstalled": "Hình như bạn không có model nào được tải cả"
"noModelsInstalled": "Dường như bạn chưa tải model nào cả! Bạn có thể <DownloadStarterModelsButton>tải xuống các model khởi đầu</DownloadStarterModelsButton> hoặc <ImportModelsButton>nhập vào thêm model</ImportModelsButton>.",
"lowVRAMMode": "Cho hiệu suất tốt nhất, hãy làm theo <LinkComponent>hướng dẫn VRAM Thấp</LinkComponent> của chúng tôi."
},
"whatsNew": {
"whatsNewInInvoke": "Có Gì Mới Ở Invoke",
@@ -2163,7 +2212,10 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"<StrongComponent>Hướng Dẫn Khu Vực FLUX (beta)</StrongComponent>: Bản beta của Hướng Dẫn Khu Vực FLUX của chúng ta đã có mắt tại bảng điều khiển lệnh khu vực."
"Chế độ VRAM thấp",
"Trình quản lý bộ nhớ động",
"Tải model nhanh hơn",
"Ít lỗi bộ nhớ hơn"
]
},
"upsell": {

View File

@@ -1,16 +1,14 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { resolveBatchValue } from 'features/queue/store/readiness';
import { groupBy } from 'lodash-es';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Batch, BatchConfig } from 'services/api/types';
const log = logger('workflows');
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
@@ -33,28 +31,54 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
const data: Batch['data'] = [];
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromImageBatch) {
if (!edge.targetHandle) {
break;
const invocationNodes = nodes.nodes.filter(isInvocationNode);
const batchNodes = invocationNodes.filter(isBatchNode);
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
// Then, we will create a batch data collection item for each group
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
for (const node of batchNodes) {
const value = resolveBatchValue(node, invocationNodes, nodes.edges);
const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value';
const edgesFromBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === sourceHandle);
if (batchGroupId !== 'None') {
// If this batch node has a batch_group_id, we will zip the data collection items
for (const edge of edgesFromBatch) {
if (!edge.targetHandle) {
break;
}
zippedBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: value,
});
}
} else {
// Otherwise, add the data collection items to the root of the batch so they are not zipped
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromBatch) {
if (!edge.targetHandle) {
break;
}
productBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: value,
});
}
if (productBatchDataCollectionItems.length > 0) {
data.push(productBatchDataCollectionItems);
}
}
batchDataCollectionItem.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: images.value,
});
}
if (batchDataCollectionItem.length > 0) {
data.push(batchDataCollectionItem);
// Finally, if this batch data collection item has any items, add it to the data array
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
data.push(zippedBatchDataCollectionItems);
}
}
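
A hedged sketch of the Batch['data'] shape this logic produces (node paths, field names, and item values below are illustrative, not from the diff): each zipped group contributes one collection item whose entries advance in lockstep, while each ungrouped batch node contributes its own item, and separate items combine as a product.

const exampleData: Batch['data'] = [
  // group "1": two zipped batch nodes -> one item, iterated in lockstep
  [
    { node_path: 'noise', field_name: 'seed', items: [1, 2, 3] },
    { node_path: 'prompt', field_name: 'value', items: ['a', 'b', 'c'] },
  ],
  // ungrouped ("None") batch node -> its own item, producted with the rest
  [{ node_path: 'cfg', field_name: 'scale', items: [5, 7.5] }],
];
// Expands to 3 x 2 = 6 queue sessions: the zipped fields advance together,
// while separate collection items multiply.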

View File

@@ -0,0 +1,72 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { BlurFilterConfig, BlurTypes } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS, isBlurTypes } from 'features/controlLayers/store/filters';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<BlurFilterConfig>;
const DEFAULTS = IMAGE_FILTERS.img_blur.buildDefaults();
export const FilterBlur = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleBlurTypeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isBlurTypes(v?.value)) {
return;
}
onChange({ ...config, blur_type: v.value });
},
[config, onChange]
);
const handleRadiusChange = useCallback(
(v: number) => {
onChange({ ...config, radius: v });
},
[config, onChange]
);
const options: { label: string; value: BlurTypes }[] = useMemo(
() => [
{ label: t('controlLayers.filter.img_blur.gaussian_type'), value: 'gaussian' },
{ label: t('controlLayers.filter.img_blur.box_type'), value: 'box' },
],
[t]
);
const value = useMemo(() => options.filter((o) => o.value === config.blur_type)[0], [options, config.blur_type]);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_type')}</FormLabel>
<Combobox value={value} options={options} onChange={handleBlurTypeChange} isSearchable={false} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_blur.blur_radius')}</FormLabel>
<CompositeSlider
value={config.radius}
defaultValue={DEFAULTS.radius}
onChange={handleRadiusChange}
min={1}
max={64}
step={0.1}
marks
/>
<CompositeNumberInput
value={config.radius}
defaultValue={DEFAULTS.radius}
onChange={handleRadiusChange}
min={1}
max={4096}
step={0.1}
/>
</FormControl>
</>
);
});
FilterBlur.displayName = 'FilterBlur';

View File

@@ -0,0 +1,111 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, CompositeNumberInput, CompositeSlider, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { NoiseFilterConfig, NoiseTypes } from 'features/controlLayers/store/filters';
import { IMAGE_FILTERS, isNoiseTypes } from 'features/controlLayers/store/filters';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<NoiseFilterConfig>;
const DEFAULTS = IMAGE_FILTERS.img_noise.buildDefaults();
export const FilterNoise = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleNoiseTypeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isNoiseTypes(v?.value)) {
return;
}
onChange({ ...config, noise_type: v.value });
},
[config, onChange]
);
const handleAmountChange = useCallback(
(v: number) => {
onChange({ ...config, amount: v });
},
[config, onChange]
);
const handleColorChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, noise_color: e.target.checked });
},
[config, onChange]
);
const handleSizeChange = useCallback(
(v: number) => {
onChange({ ...config, size: v });
},
[config, onChange]
);
const options: { label: string; value: NoiseTypes }[] = useMemo(
() => [
{ label: t('controlLayers.filter.img_noise.gaussian_type'), value: 'gaussian' },
{ label: t('controlLayers.filter.img_noise.salt_and_pepper_type'), value: 'salt_and_pepper' },
],
[t]
);
const value = useMemo(() => options.filter((o) => o.value === config.noise_type)[0], [options, config.noise_type]);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_type')}</FormLabel>
<Combobox value={value} options={options} onChange={handleNoiseTypeChange} isSearchable={false} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_amount')}</FormLabel>
<CompositeSlider
value={config.amount}
defaultValue={DEFAULTS.amount}
onChange={handleAmountChange}
min={0}
max={1}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.amount}
defaultValue={DEFAULTS.amount}
onChange={handleAmountChange}
min={0}
max={1}
step={0.01}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlLayers.filter.img_noise.size')}</FormLabel>
<CompositeSlider
value={config.size}
defaultValue={DEFAULTS.size}
onChange={handleSizeChange}
min={1}
max={16}
step={1}
marks
/>
<CompositeNumberInput
value={config.size}
defaultValue={DEFAULTS.size}
onChange={handleSizeChange}
min={1}
max={256}
step={1}
/>
</FormControl>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlLayers.filter.img_noise.noise_color')}</FormLabel>
<Switch defaultChecked={DEFAULTS.noise_color} isChecked={config.noise_color} onChange={handleColorChange} />
</FormControl>
</>
);
});
FilterNoise.displayName = 'FilterNoise';

View File

@@ -1,4 +1,5 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { FilterBlur } from 'features/controlLayers/components/Filters/FilterBlur';
import { FilterCannyEdgeDetection } from 'features/controlLayers/components/Filters/FilterCannyEdgeDetection';
import { FilterColorMap } from 'features/controlLayers/components/Filters/FilterColorMap';
import { FilterContentShuffle } from 'features/controlLayers/components/Filters/FilterContentShuffle';
@@ -8,6 +9,7 @@ import { FilterHEDEdgeDetection } from 'features/controlLayers/components/Filter
import { FilterLineartEdgeDetection } from 'features/controlLayers/components/Filters/FilterLineartEdgeDetection';
import { FilterMediaPipeFaceDetection } from 'features/controlLayers/components/Filters/FilterMediaPipeFaceDetection';
import { FilterMLSDDetection } from 'features/controlLayers/components/Filters/FilterMLSDDetection';
import { FilterNoise } from 'features/controlLayers/components/Filters/FilterNoise';
import { FilterPiDiNetEdgeDetection } from 'features/controlLayers/components/Filters/FilterPiDiNetEdgeDetection';
import { FilterSpandrel } from 'features/controlLayers/components/Filters/FilterSpandrel';
import type { FilterConfig } from 'features/controlLayers/store/filters';
@@ -19,6 +21,10 @@ type Props = { filterConfig: FilterConfig; onChange: (filterConfig: FilterConfig
export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
const { t } = useTranslation();
if (filterConfig.type === 'img_blur') {
return <FilterBlur config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'canny_edge_detection') {
return <FilterCannyEdgeDetection config={filterConfig} onChange={onChange} />;
}
@@ -59,6 +65,10 @@ export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
return <FilterPiDiNetEdgeDetection config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'img_noise') {
return <FilterNoise config={filterConfig} onChange={onChange} />;
}
if (filterConfig.type === 'spandrel_filter') {
return <FilterSpandrel config={filterConfig} onChange={onChange} />;
}

View File

@@ -297,10 +297,9 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
const imageState = imageDTOToImageObject(filterResult.value);
this.$imageState.set(imageState);
// Destroy any existing masked image and create a new one
if (this.imageModule) {
this.imageModule.destroy();
}
// Stash the existing image module - we will destroy it after the new image is rendered to prevent a flash
// of an empty layer
const oldImageModule = this.imageModule;
this.imageModule = new CanvasObjectImage(imageState, this);
@@ -309,6 +308,16 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.konva.group.add(this.imageModule.konva.group);
// The filtered image has some transparency, so we need to hide the objects of the parent entity to prevent the
// two images from blending. We will show the objects again in the teardown method, which is always called after
// the filter finishes (applied or canceled).
this.parent.renderer.hideObjects();
if (oldImageModule) {
// Destroy the old image module now that the new one is rendered
oldImageModule.destroy();
}
// The processing is complete, so we can set the last processed hash and isProcessing to false
this.$lastProcessedHash.set(hash);
@@ -424,6 +433,8 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
teardown = () => {
this.unsubscribe();
// Re-enable the objects of the parent entity
this.parent.renderer.showObjects();
this.konva.group.remove();
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.
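
The swap above follows a simple stash-then-destroy pattern; a minimal sketch with illustrative names (the real objects are CanvasObjectImage modules):

// Keep the outgoing image alive until its replacement has rendered, so the
// layer never flashes empty between the two.
async function swapImage<T extends { destroy(): void }>(
  old: T | null,
  renderNext: () => Promise<T>
): Promise<T> {
  const next = await renderNext(); // render the replacement first
  old?.destroy(); // only then tear down the previous image
  return next;
}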

View File

@@ -185,6 +185,14 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
return didRender;
};
hideObjects = () => {
this.konva.objectGroup.hide();
};
showObjects = () => {
this.konva.objectGroup.show();
};
adoptObjectRenderer = (renderer: AnyObjectRenderer) => {
this.renderers.set(renderer.id, renderer);
renderer.konva.group.moveTo(this.konva.objectGroup);

View File

@@ -10,6 +10,7 @@ import {
getKonvaNodeDebugAttrs,
getPrefixedId,
offsetCoord,
roundRect,
} from 'features/controlLayers/konva/util';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, RectWithRotation } from 'features/controlLayers/store/types';
@@ -773,7 +774,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const rect = this.getRelativeRect();
const rasterizeResult = await withResultAsync(() =>
this.parent.renderer.rasterize({
rect,
rect: roundRect(rect),
replaceObjects: true,
ignoreCache: true,
attrs: { opacity: 1, filters: [] },

View File

@@ -740,3 +740,12 @@ export const getColorAtCoordinate = (stage: Konva.Stage, coord: Coordinate): Rgb
return { r, g, b };
};
export const roundRect = (rect: Rect): Rect => {
return {
x: Math.round(rect.x),
y: Math.round(rect.y),
width: Math.round(rect.width),
height: Math.round(rect.height),
};
};
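
A quick illustration of why this helper exists: transforms routinely yield fractional rects, and rasterizing at sub-pixel offsets can blur the result or clip it by a pixel, so the rect is snapped to whole pixels first. The values below are arbitrary:

const rect = { x: 10.4, y: 20.6, width: 99.5, height: 100.2 };
roundRect(rect); // -> { x: 10, y: 21, width: 100, height: 100 }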

View File

@@ -95,6 +95,28 @@ const zSpandrelFilterConfig = z.object({
});
export type SpandrelFilterConfig = z.infer<typeof zSpandrelFilterConfig>;
const zBlurTypes = z.enum(['gaussian', 'box']);
export type BlurTypes = z.infer<typeof zBlurTypes>;
export const isBlurTypes = (v: unknown): v is BlurTypes => zBlurTypes.safeParse(v).success;
const zBlurFilterConfig = z.object({
type: z.literal('img_blur'),
blur_type: zBlurTypes,
radius: z.number().gte(0),
});
export type BlurFilterConfig = z.infer<typeof zBlurFilterConfig>;
const zNoiseTypes = z.enum(['gaussian', 'salt_and_pepper']);
export type NoiseTypes = z.infer<typeof zNoiseTypes>;
export const isNoiseTypes = (v: unknown): v is NoiseTypes => zNoiseTypes.safeParse(v).success;
const zNoiseFilterConfig = z.object({
type: z.literal('img_noise'),
noise_type: zNoiseTypes,
amount: z.number().gte(0).lte(1),
noise_color: z.boolean(),
size: z.number().int().gte(1),
});
export type NoiseFilterConfig = z.infer<typeof zNoiseFilterConfig>;
const zFilterConfig = z.discriminatedUnion('type', [
zCannyEdgeDetectionFilterConfig,
zColorMapFilterConfig,
@@ -109,6 +131,8 @@ const zFilterConfig = z.discriminatedUnion('type', [
zPiDiNetEdgeDetectionFilterConfig,
zDWOpenposeDetectionFilterConfig,
zSpandrelFilterConfig,
zBlurFilterConfig,
zNoiseFilterConfig,
]);
export type FilterConfig = z.infer<typeof zFilterConfig>;
@@ -126,6 +150,8 @@ const zFilterType = z.enum([
'pidi_edge_detection',
'dw_openpose_detection',
'spandrel_filter',
'img_blur',
'img_noise',
]);
export type FilterType = z.infer<typeof zFilterType>;
export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success;
@@ -429,6 +455,62 @@ export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key
return true;
},
},
img_blur: {
type: 'img_blur',
buildDefaults: () => ({
type: 'img_blur',
blur_type: 'gaussian',
radius: 8,
}),
buildGraph: ({ image_name }, { blur_type, radius }) => {
const graph = new Graph(getPrefixedId('img_blur'));
const node = graph.addNode({
id: getPrefixedId('img_blur'),
type: 'img_blur',
image: { image_name },
blur_type: blur_type,
radius: radius,
});
return {
graph,
outputNodeId: node.id,
};
},
},
img_noise: {
type: 'img_noise',
buildDefaults: () => ({
type: 'img_noise',
noise_type: 'gaussian',
amount: 0.3,
noise_color: true,
size: 1,
}),
buildGraph: ({ image_name }, { noise_type, amount, noise_color, size }) => {
const graph = new Graph(getPrefixedId('img_noise'));
const node = graph.addNode({
id: getPrefixedId('img_noise'),
type: 'img_noise',
image: { image_name },
noise_type: noise_type,
amount: amount,
noise_color: noise_color,
size: size,
});
const rand = graph.addNode({
id: getPrefixedId('rand_int'),
use_cache: false,
type: 'rand_int',
low: 0,
high: 2147483647,
});
graph.addEdge(rand, 'value', node, 'seed');
return {
graph,
outputNodeId: node.id,
};
},
},
} as const;
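
A short sketch of how these registry entries are consumed, under stated assumptions: the image argument is shown as a bare { image_name } even though the real builder receives a full image DTO, and 'example.png' is a placeholder.

const config = IMAGE_FILTERS.img_noise.buildDefaults();
// zFilterConfig.parse validates against the discriminated union and narrows on `type`
const parsed = zFilterConfig.parse(config);
if (parsed.type === 'img_noise') {
  // buildGraph wires a rand_int node into the img_noise `seed` input, with
  // use_cache false, so every run of the filter draws a fresh noise seed
  const { graph, outputNodeId } = IMAGE_FILTERS.img_noise.buildGraph({ image_name: 'example.png' }, parsed);
}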
/**

View File

@@ -1,4 +1,5 @@
import type {
BlurFilterConfig,
CannyEdgeDetectionFilterConfig,
ColorMapFilterConfig,
ContentShuffleFilterConfig,
@@ -12,6 +13,7 @@ import type {
LineartEdgeDetectionFilterConfig,
MediaPipeFaceDetectionFilterConfig,
MLSDDetectionFilterConfig,
NoiseFilterConfig,
NormalMapFilterConfig,
PiDiNetEdgeDetectionFilterConfig,
} from 'features/controlLayers/store/filters';
@@ -54,6 +56,7 @@ describe('Control Adapter Types', () => {
});
test('Processor Configs', () => {
// Types derived from OpenAPI
type _BlurFilterConfig = Required<Pick<Invocation<'img_blur'>, 'type' | 'radius' | 'blur_type'>>;
type _CannyEdgeDetectionFilterConfig = Required<
Pick<Invocation<'canny_edge_detection'>, 'type' | 'low_threshold' | 'high_threshold'>
>;
@@ -71,6 +74,9 @@ describe('Control Adapter Types', () => {
type _MLSDDetectionFilterConfig = Required<
Pick<Invocation<'mlsd_detection'>, 'type' | 'score_threshold' | 'distance_threshold'>
>;
type _NoiseFilterConfig = Required<
Pick<Invocation<'img_noise'>, 'type' | 'noise_type' | 'amount' | 'noise_color' | 'size'>
>;
type _NormalMapFilterConfig = Required<Pick<Invocation<'normal_map'>, 'type'>>;
type _DWOpenposeDetectionFilterConfig = Required<
Pick<Invocation<'dw_openpose_detection'>, 'type' | 'draw_body' | 'draw_face' | 'draw_hands'>
@@ -81,6 +87,7 @@ describe('Control Adapter Types', () => {
// The processor configs are manually modeled zod schemas. This test ensures that the inferred types are correct.
// The types prefixed with `_` are types generated from OpenAPI, while the types without the prefix are manually modeled.
assert<Equals<_BlurFilterConfig, BlurFilterConfig>>();
assert<Equals<_CannyEdgeDetectionFilterConfig, CannyEdgeDetectionFilterConfig>>();
assert<Equals<_ColorMapFilterConfig, ColorMapFilterConfig>>();
assert<Equals<_ContentShuffleFilterConfig, ContentShuffleFilterConfig>>();
@@ -90,6 +97,7 @@ describe('Control Adapter Types', () => {
assert<Equals<_LineartEdgeDetectionFilterConfig, LineartEdgeDetectionFilterConfig>>();
assert<Equals<_MediaPipeFaceDetectionFilterConfig, MediaPipeFaceDetectionFilterConfig>>();
assert<Equals<_MLSDDetectionFilterConfig, MLSDDetectionFilterConfig>>();
assert<Equals<_NoiseFilterConfig, NoiseFilterConfig>>();
assert<Equals<_NormalMapFilterConfig, NormalMapFilterConfig>>();
assert<Equals<_DWOpenposeDetectionFilterConfig, DWOpenposeDetectionFilterConfig>>();
assert<Equals<_PiDiNetEdgeDetectionFilterConfig, PiDiNetEdgeDetectionFilterConfig>>();

View File

@@ -11,7 +11,7 @@ import type { DndTargetState } from 'features/dnd/types';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { uploadImages } from 'services/api/endpoints/images';
import { useBoardName } from 'services/api/hooks/useBoardName';
@@ -72,11 +72,10 @@ export const FullscreenDropzone = memo(() => {
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const [dndState, setDndState] = useState<DndTargetState>('idle');
const uploadFilesSchema = useMemo(() => getFilesSchema(maxImageUploadCount), [maxImageUploadCount]);
const validateAndUploadFiles = useCallback(
(files: File[]) => {
const { getState } = getStore();
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
const parseResult = uploadFilesSchema.safeParse(files);
if (!parseResult.success) {
@@ -105,7 +104,18 @@ export const FullscreenDropzone = memo(() => {
uploadImages(uploadArgs);
},
[maxImageUploadCount, t, uploadFilesSchema]
[maxImageUploadCount, t]
);
const onPaste = useCallback(
(e: ClipboardEvent) => {
if (!e.clipboardData?.files) {
return;
}
const files = Array.from(e.clipboardData.files);
validateAndUploadFiles(files);
},
[validateAndUploadFiles]
);
useEffect(() => {
@@ -144,24 +154,12 @@ export const FullscreenDropzone = memo(() => {
}, [validateAndUploadFiles]);
useEffect(() => {
const controller = new AbortController();
document.addEventListener(
'paste',
(e) => {
if (!e.clipboardData?.files) {
return;
}
const files = Array.from(e.clipboardData.files);
validateAndUploadFiles(files);
},
{ signal: controller.signal }
);
window.addEventListener('paste', onPaste);
return () => {
controller.abort();
window.removeEventListener('paste', onPaste);
};
}, [validateAndUploadFiles]);
}, [onPaste]);
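
The general shape of this refactor, as a hedged sketch: hoisting the handler into useCallback gives the effect a stable reference, so its cleanup removes exactly the listener it added and the AbortController is no longer needed. The hook name is illustrative:

import { useCallback, useEffect } from 'react';

const usePasteFiles = (onFiles: (files: File[]) => void) => {
  const onPaste = useCallback(
    (e: ClipboardEvent) => {
      if (!e.clipboardData?.files) {
        return;
      }
      onFiles(Array.from(e.clipboardData.files));
    },
    [onFiles]
  );
  useEffect(() => {
    window.addEventListener('paste', onPaste);
    return () => window.removeEventListener('paste', onPaste);
  }, [onPaste]);
};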
return (
<Box ref={ref} data-dnd-state={dndState} sx={sx}>

View File

@@ -1,3 +1,4 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
@@ -9,7 +10,6 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
import type { BoardId } from 'features/gallery/store/types';
import {
addImagesToBoard,
addImagesToNodeImageFieldCollectionAction,
createNewCanvasEntityFromImage,
removeImagesFromBoard,
replaceCanvasEntityObjectsWithImage,
@@ -19,10 +19,14 @@ import {
setRegionalGuidanceReferenceImage,
setUpscaleInitialImage,
} from 'features/imageActions/actions';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('dnd');
type RecordUnknown = Record<string | symbol, unknown>;
type DndData<
@@ -268,15 +272,27 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
}
const { fieldIdentifier } = targetData.payload;
const imageDTOs: ImageDTO[] = [];
if (singleImageDndSource.typeGuard(sourceData)) {
imageDTOs.push(sourceData.payload.imageDTO);
} else {
imageDTOs.push(...sourceData.payload.imageDTOs);
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
if (singleImageDndSource.typeGuard(sourceData)) {
newValue.push({ image_name: sourceData.payload.imageDTO.image_name });
} else {
newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name })));
}
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue }));
},
};
//#endregion

View File

@@ -1,4 +1,3 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
@@ -31,19 +30,15 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { uniqBy } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const setGlobalReferenceImage = (arg: {
imageDTO: ImageDTO;
entityIdentifier: CanvasEntityIdentifier<'reference_image'>;
@@ -77,54 +72,6 @@ export const setNodeImageFieldImage = (arg: {
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
};
export const addImagesToNodeImageFieldCollectionAction = (arg: {
imageDTOs: ImageDTO[];
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
const uniqueImages = uniqBy(images, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};
export const removeImageFromNodeImageFieldCollectionAction = (arg: {
imageName: string;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageName, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
return;
}
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg;
dispatch(imageToCompareChanged(imageDTO));

View File

@@ -36,10 +36,15 @@ const FieldHandle = (props: FieldHandleProps) => {
borderWidth: !isSingle(type) ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
borderRadius: isModelType || type.batch ? 4 : '100%',
zIndex: 1,
transformOrigin: 'center',
};
if (type.batch) {
s.transform = 'rotate(45deg) translateX(-0.3rem) translateY(-0.3rem)';
}
if (handleType === 'target') {
s.insetInlineStart = '-1rem';
} else {

View File

@@ -1,5 +1,10 @@
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { NumberFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent';
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
import { StringGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import {
@@ -21,8 +26,12 @@ import {
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
isEnumFieldInputTemplate,
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isFloatGeneratorFieldInputInstance,
isFloatGeneratorFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isFluxVAEModelFieldInputInstance,
@@ -31,8 +40,12 @@ import {
isImageFieldCollectionInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
isIntegerFieldInputTemplate,
isIntegerGeneratorFieldInputInstance,
isIntegerGeneratorFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isLoRAModelFieldInputInstance,
@@ -51,8 +64,12 @@ import {
isSDXLRefinerModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isStringGeneratorFieldInputInstance,
isStringGeneratorFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
@@ -97,6 +114,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
return <StringFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
@@ -105,13 +126,22 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
) {
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
@@ -216,6 +246,18 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFloatGeneratorFieldInputInstance(fieldInstance) && isFloatGeneratorFieldInputTemplate(fieldTemplate)) {
return <FloatGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isIntegerGeneratorFieldInputInstance(fieldInstance) && isIntegerGeneratorFieldInputTemplate(fieldTemplate)) {
return <IntegerGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isStringGeneratorFieldInputInstance(fieldInstance) && isStringGeneratorFieldInputTemplate(fieldTemplate)) {
return <StringGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (fieldTemplate) {
// Fallback for when there is no component for the type
return null;

View File

@@ -1,6 +1,6 @@
import { Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectInvocationNode, selectNodesSlice } from 'features/nodes/store/selectors';
@@ -18,7 +18,7 @@ export const InvocationInputFieldCheck = memo(({ nodeId, fieldName, children }:
const templates = useStore($templates);
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodesSlice) => {
createMemoizedSelector(selectNodesSlice, (nodesSlice) => {
const node = selectInvocationNode(nodesSlice, nodeId);
const instance = node.data.inputs[fieldName];
const template = templates[node.data.type];

View File

@@ -26,7 +26,7 @@ const EnumFieldInputComponent = (props: FieldComponentProps<EnumFieldInputInstan
);
return (
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value}>
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value} size="sm">
{fieldTemplate.options.map((option) => (
<option key={option} value={option}>
{fieldTemplate.ui_choice_labels ? fieldTemplate.ui_choice_labels[option] : option}

View File

@@ -0,0 +1,57 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { FloatGeneratorArithmeticSequence } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type FloatGeneratorArithmeticSequenceSettingsProps = {
state: FloatGeneratorArithmeticSequence;
onChange: (state: FloatGeneratorArithmeticSequence) => void;
};
export const FloatGeneratorArithmeticSequenceSettings = memo(
({ state, onChange }: FloatGeneratorArithmeticSequenceSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput
value={state.start}
onChange={onChangeStart}
min={-Infinity}
max={Infinity}
step={0.01}
/>
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
</Flex>
);
}
);
FloatGeneratorArithmeticSequenceSettings.displayName = 'FloatGeneratorArithmeticSequenceSettings';
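
A worked example of the arithmetic-sequence semantics these settings feed (the resolver lives in features/nodes/types/field and is not shown in this diff, so the output is inferred from the name):

// state: { start: 0, step: 0.5, count: 4 }
// resolveFloatGeneratorField -> [0, 0.5, 1, 1.5]   (start + i * step, for i = 0..count - 1)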

View File

@@ -0,0 +1,120 @@
import { Flex, Select, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { FloatGeneratorArithmeticSequenceSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorArithmeticSequenceSettings';
import { FloatGeneratorLinearDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorLinearDistributionSettings';
import { FloatGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorParseStringSettings';
import { FloatGeneratorUniformRandomDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorUniformRandomDistributionSettings';
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
import { fieldFloatGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
import type { FloatGeneratorFieldInputInstance, FloatGeneratorFieldInputTemplate } from 'features/nodes/types/field';
import {
FloatGeneratorArithmeticSequenceType,
FloatGeneratorLinearDistributionType,
FloatGeneratorParseStringType,
FloatGeneratorUniformRandomDistributionType,
getFloatGeneratorDefaults,
resolveFloatGeneratorField,
} from 'features/nodes/types/field';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebounce } from 'use-debounce';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
export const FloatGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<FloatGeneratorFieldInputInstance, FloatGeneratorFieldInputTemplate>) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onChange = useCallback(
(value: FloatGeneratorFieldInputInstance['value']) => {
dispatch(
fieldFloatGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const onChangeGeneratorType = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const value = getFloatGeneratorDefaults(e.target.value as FloatGeneratorFieldInputInstance['value']['type']);
if (!value) {
return;
}
dispatch(
fieldFloatGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const [debouncedField] = useDebounce(field, 300);
const resolvedValuesAsString = useMemo(() => {
if (
debouncedField.value.type === FloatGeneratorUniformRandomDistributionType &&
isNil(debouncedField.value.seed)
) {
const { count } = debouncedField.value;
return `<${t('nodes.generatorNRandomValues', { count })}>`;
}
const resolvedValues = resolveFloatGeneratorField(debouncedField);
if (resolvedValues.length === 0) {
return `<${t('nodes.generatorNoValues')}>`;
} else {
return resolvedValues.map((val) => round(val, 2)).join(', ');
}
}, [debouncedField, t]);
return (
<Flex flexDir="column" gap={2}>
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
<option value={FloatGeneratorArithmeticSequenceType}>{t('nodes.arithmeticSequence')}</option>
<option value={FloatGeneratorLinearDistributionType}>{t('nodes.linearDistribution')}</option>
<option value={FloatGeneratorUniformRandomDistributionType}>{t('nodes.uniformRandomDistribution')}</option>
<option value={FloatGeneratorParseStringType}>{t('nodes.parseString')}</option>
</Select>
{field.value.type === FloatGeneratorArithmeticSequenceType && (
<FloatGeneratorArithmeticSequenceSettings state={field.value} onChange={onChange} />
)}
{field.value.type === FloatGeneratorLinearDistributionType && (
<FloatGeneratorLinearDistributionSettings state={field.value} onChange={onChange} />
)}
{field.value.type === FloatGeneratorUniformRandomDistributionType && (
<FloatGeneratorUniformRandomDistributionSettings state={field.value} onChange={onChange} />
)}
{field.value.type === FloatGeneratorParseStringType && (
<FloatGeneratorParseStringSettings state={field.value} onChange={onChange} />
)}
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
<Flex w="full" h="auto">
<OverlayScrollbarsComponent
className="nodrag nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
{resolvedValuesAsString}
</Text>
</OverlayScrollbarsComponent>
</Flex>
</Flex>
</Flex>
);
}
);
FloatGeneratorFieldInputComponent.displayName = 'FloatGeneratorFieldInputComponent';

View File

@@ -0,0 +1,57 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { FloatGeneratorLinearDistribution } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type FloatGeneratorLinearDistributionSettingsProps = {
state: FloatGeneratorLinearDistribution;
onChange: (state: FloatGeneratorLinearDistribution) => void;
};
export const FloatGeneratorLinearDistributionSettings = memo(
({ state, onChange }: FloatGeneratorLinearDistributionSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeEnd = useCallback(
(end: number) => {
onChange({ ...state, end });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput
value={state.start}
onChange={onChangeStart}
min={-Infinity}
max={Infinity}
step={0.01}
/>
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.end')}</FormLabel>
<CompositeNumberInput value={state.end} onChange={onChangeEnd} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
</Flex>
);
}
);
FloatGeneratorLinearDistributionSettings.displayName = 'FloatGeneratorLinearDistributionSettings';

View File

@@ -0,0 +1,39 @@
import { Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
import type { FloatGeneratorParseString } from 'features/nodes/types/field';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type FloatGeneratorParseStringSettingsProps = {
state: FloatGeneratorParseString;
onChange: (state: FloatGeneratorParseString) => void;
};
export const FloatGeneratorParseStringSettings = memo(({ state, onChange }: FloatGeneratorParseStringSettingsProps) => {
const { t } = useTranslation();
const onChangeSplitOn = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...state, splitOn: e.target.value });
},
[onChange, state]
);
const onChangeInput = useCallback(
(input: string) => {
onChange({ ...state, input });
},
[onChange, state]
);
return (
<Flex gap={2} flexDir="column">
<FormControl orientation="vertical">
<FormLabel>{t('nodes.splitOn')}</FormLabel>
<Input value={state.splitOn} onChange={onChangeSplitOn} />
</FormControl>
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
</Flex>
);
});
FloatGeneratorParseStringSettings.displayName = 'FloatGeneratorParseStringSettings';

View File

@@ -0,0 +1,78 @@
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { FloatGeneratorUniformRandomDistribution } from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type FloatGeneratorUniformRandomDistributionSettingsProps = {
state: FloatGeneratorUniformRandomDistribution;
onChange: (state: FloatGeneratorUniformRandomDistribution) => void;
};
export const FloatGeneratorUniformRandomDistributionSettings = memo(
({ state, onChange }: FloatGeneratorUniformRandomDistributionSettingsProps) => {
const { t } = useTranslation();
const onChangeMin = useCallback(
(min: number) => {
onChange({ ...state, min });
},
[onChange, state]
);
const onChangeMax = useCallback(
(max: number) => {
onChange({ ...state, max });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
const onToggleSeed = useCallback(() => {
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
}, [onChange, state]);
const onChangeSeed = useCallback(
(seed?: number | null) => {
onChange({ ...state, seed });
},
[onChange, state]
);
return (
<Flex gap={2} flexDir="column">
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.min')}</FormLabel>
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.max')}</FormLabel>
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel alignItems="center" justifyContent="space-between" m={0} display="flex" w="full">
{t('common.seed')}
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
</FormLabel>
<CompositeNumberInput
isDisabled={isNil(state.seed)}
// This cast is safe only because we disable the element when seed is not a number - the `...` is
// rendered in the input field in this case
value={state.seed ?? ('...' as unknown as number)}
onChange={onChangeSeed}
min={-Infinity}
max={Infinity}
/>
</FormControl>
</Flex>
</Flex>
);
}
);
FloatGeneratorUniformRandomDistributionSettings.displayName = 'FloatGeneratorUniformRandomDistributionSettings';

View File

@@ -0,0 +1,85 @@
import { FormControl, FormLabel, IconButton, Spacer, Textarea } from '@invoke-ai/ui-library';
import { toast } from 'features/toast/toast';
import { isString } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiUploadFill } from 'react-icons/pi';
const MAX_SIZE = 1024 * 128; // 128KB, we don't want to load huge files into node values...
type Props = {
value: string;
onChange: (value: string) => void;
};
export const GeneratorTextareaWithFileUpload = memo(({ value, onChange }: Props) => {
const { t } = useTranslation();
const onDropAccepted = useCallback(
(files: File[]) => {
const file = files[0];
if (!file) {
return;
}
const reader = new FileReader();
reader.onload = () => {
const result = reader.result;
if (!isString(result)) {
return;
}
onChange(result);
};
reader.onerror = () => {
toast({
title: 'Failed to load file',
status: 'error',
});
};
reader.readAsText(file);
},
[onChange]
);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'text/csv': ['.csv'], 'text/plain': ['.txt'] },
maxSize: MAX_SIZE,
onDropAccepted,
noDrag: true,
multiple: false,
});
const onChangeInput = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
onChange(e.target.value);
},
[onChange]
);
return (
<FormControl orientation="vertical" position="relative" alignItems="stretch">
<FormLabel m={0} display="flex" alignItems="center">
{t('common.input')}
<Spacer />
<IconButton
tooltip={t('nodes.generatorLoadFromFile')}
aria-label={t('nodes.generatorLoadFromFile')}
icon={<PiUploadFill />}
variant="link"
{...getRootProps()}
/>
<input {...getInputProps()} />
</FormLabel>
<Textarea
className="nowheel nodrag nopan"
value={value}
onChange={onChangeInput}
p={2}
resize="none"
rows={5}
fontSize="sm"
/>
</FormControl>
);
});
GeneratorTextareaWithFileUpload.displayName = 'GeneratorTextareaWithFileUpload';
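One gap worth noting above: with maxSize set, react-dropzone silently discards oversized files, because only onDropAccepted is handled. A minimal sketch of surfacing those rejections, assuming the toast helper also accepts a description (an assumption); onDropRejected is a real react-dropzone option, but wiring it up here is an illustration, not part of this change:

import type { FileRejection } from 'react-dropzone';
import { toast } from 'features/toast/toast';

// Illustrative handler: report files rejected by the maxSize/accept filters.
const onDropRejected = (rejections: FileRejection[]) => {
  for (const { file, errors } of rejections) {
    toast({
      title: `Failed to load ${file.name}`,
      description: errors.map((e) => e.message).join(', '), // assumes toast supports description
      status: 'error',
    });
  }
};

// Passed alongside the existing options:
// useDropzone({ accept, maxSize: MAX_SIZE, onDropAccepted, onDropRejected, noDrag: true, multiple: false });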

View File

@@ -10,9 +10,9 @@ import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type { ImageField } from 'features/nodes/types/common';
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
@@ -61,15 +61,12 @@ export const ImageFieldCollectionInputComponent = memo(
);
const onRemoveImage = useCallback(
(imageName: string) => {
removeImageFromNodeImageFieldCollectionAction({
imageName,
fieldIdentifier: { nodeId, fieldName: field.name },
dispatch: store.dispatch,
getState: store.getState,
});
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldImageCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, nodeId, store.dispatch, store.getState]
[field.name, field.value, nodeId, store]
);
return (
@@ -90,7 +87,7 @@ export const ImageFieldCollectionInputComponent = memo(
isError={isInvalid}
onUpload={onUpload}
fontSize={24}
variant="outline"
variant="ghost"
/>
)}
{field.value && field.value.length > 0 && (
@@ -102,9 +99,9 @@ export const ImageFieldCollectionInputComponent = memo(
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(4, 1fr)" gap={1}>
{field.value.map(({ image_name }) => (
<GridItem key={image_name} position="relative" className="nodrag">
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<ImageGridItemContent value={value} index={index} onRemoveImage={onRemoveImage} />
</GridItem>
))}
</Grid>
@@ -124,11 +121,11 @@ export const ImageFieldCollectionInputComponent = memo(
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';
const ImageGridItemContent = memo(
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
const query = useGetImageDTOQuery(imageName);
({ value, index, onRemoveImage }: { value: ImageField; index: number; onRemoveImage: (index: number) => void }) => {
const query = useGetImageDTOQuery(value.image_name);
const onClickRemove = useCallback(() => {
onRemoveImage(imageName);
}, [imageName, onRemoveImage]);
onRemoveImage(index);
}, [index, onRemoveImage]);
if (query.isLoading) {
return <IAINoContentFallbackWithSpinner />;

View File

@@ -0,0 +1,51 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { IntegerGeneratorArithmeticSequence } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type IntegerGeneratorArithmeticSequenceSettingsProps = {
state: IntegerGeneratorArithmeticSequence;
onChange: (state: IntegerGeneratorArithmeticSequence) => void;
};
export const IntegerGeneratorArithmeticSequenceSettings = memo(
({ state, onChange }: IntegerGeneratorArithmeticSequenceSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
</Flex>
);
}
);
IntegerGeneratorArithmeticSequenceSettings.displayName = 'IntegerGeneratorArithmeticSequenceSettings';
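The settings panel above only edits the generator's parameters; turning them into numbers happens in resolveIntegerGeneratorField. As a rough sketch of what resolving an arithmetic sequence amounts to, assuming the { start, step, count } shape shown above (the real resolver may differ in details):

// Illustrative resolver: start, start + step, start + 2*step, ... for `count` terms.
const resolveArithmeticSequenceSketch = (state: { start: number; step: number; count: number }): number[] =>
  Array.from({ length: state.count }, (_, i) => state.start + i * state.step);

// resolveArithmeticSequenceSketch({ start: 0, step: 5, count: 4 }) -> [0, 5, 10, 15]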

View File

@@ -0,0 +1,122 @@
import { Flex, Select, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { IntegerGeneratorArithmeticSequenceSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorArithmeticSequenceSettings';
import { IntegerGeneratorLinearDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorLinearDistributionSettings';
import { IntegerGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorParseStringSettings';
import { IntegerGeneratorUniformRandomDistributionSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorUniformRandomDistributionSettings';
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
import { fieldIntegerGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
import type {
IntegerGeneratorFieldInputInstance,
IntegerGeneratorFieldInputTemplate,
} from 'features/nodes/types/field';
import {
getIntegerGeneratorDefaults,
IntegerGeneratorArithmeticSequenceType,
IntegerGeneratorLinearDistributionType,
IntegerGeneratorParseStringType,
IntegerGeneratorUniformRandomDistributionType,
resolveIntegerGeneratorField,
} from 'features/nodes/types/field';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebounce } from 'use-debounce';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
export const IntegerGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<IntegerGeneratorFieldInputInstance, IntegerGeneratorFieldInputTemplate>) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onChange = useCallback(
(value: IntegerGeneratorFieldInputInstance['value']) => {
dispatch(
fieldIntegerGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const onChangeGeneratorType = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const value = getIntegerGeneratorDefaults(
e.target.value as IntegerGeneratorFieldInputInstance['value']['type']
);
dispatch(
fieldIntegerGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const [debouncedField] = useDebounce(field, 300);
const resolvedValuesAsString = useMemo(() => {
if (
debouncedField.value.type === IntegerGeneratorUniformRandomDistributionType &&
isNil(debouncedField.value.seed)
) {
const { count } = debouncedField.value;
return `<${t('nodes.generatorNRandomValues', { count })}>`;
}
const resolvedValues = resolveIntegerGeneratorField(debouncedField);
if (resolvedValues.length === 0) {
return `<${t('nodes.generatorNoValues')}>`;
} else {
return resolvedValues.map((val) => round(val, 2)).join(', ');
}
}, [debouncedField, t]);
return (
<Flex flexDir="column" gap={2}>
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
<option value={IntegerGeneratorArithmeticSequenceType}>{t('nodes.arithmeticSequence')}</option>
<option value={IntegerGeneratorLinearDistributionType}>{t('nodes.linearDistribution')}</option>
<option value={IntegerGeneratorUniformRandomDistributionType}>{t('nodes.uniformRandomDistribution')}</option>
<option value={IntegerGeneratorParseStringType}>{t('nodes.parseString')}</option>
</Select>
{field.value.type === IntegerGeneratorArithmeticSequenceType && (
<IntegerGeneratorArithmeticSequenceSettings state={field.value} onChange={onChange} />
)}
{field.value.type === IntegerGeneratorLinearDistributionType && (
<IntegerGeneratorLinearDistributionSettings state={field.value} onChange={onChange} />
)}
{field.value.type === IntegerGeneratorUniformRandomDistributionType && (
<IntegerGeneratorUniformRandomDistributionSettings state={field.value} onChange={onChange} />
)}
{field.value.type === IntegerGeneratorParseStringType && (
<IntegerGeneratorParseStringSettings state={field.value} onChange={onChange} />
)}
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
<Flex w="full" h="auto">
<OverlayScrollbarsComponent
className="nodrag nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
{resolvedValuesAsString}
</Text>
</OverlayScrollbarsComponent>
</Flex>
</Flex>
</Flex>
);
}
);
IntegerGeneratorFieldInputComponent.displayName = 'IntegerGeneratorFieldInputComponent';
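Note how the Select above swaps the entire generator value via getIntegerGeneratorDefaults, so switching types never leaves stale fields behind. A sketch of what such a defaults factory could look like for two of the variants; the tag strings and default numbers here are hypothetical, not the app's real values:

// Illustrative defaults factory keyed on the generator's discriminant tag.
type GeneratorSketch =
  | { type: 'arithmetic_sequence'; start: number; step: number; count: number }
  | { type: 'parse_string'; input: string; splitOn: string };

const getDefaultsSketch = (type: GeneratorSketch['type']): GeneratorSketch => {
  switch (type) {
    case 'arithmetic_sequence':
      return { type, start: 0, step: 1, count: 10 };
    case 'parse_string':
      return { type, input: '', splitOn: ',' };
  }
};

Because the union is discriminated on type, the component can pick the matching settings panel with a plain equality check, as the JSX above does.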

View File

@@ -0,0 +1,51 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { IntegerGeneratorLinearDistribution } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type IntegerGeneratorLinearDistributionSettingsProps = {
state: IntegerGeneratorLinearDistribution;
onChange: (state: IntegerGeneratorLinearDistribution) => void;
};
export const IntegerGeneratorLinearDistributionSettings = memo(
({ state, onChange }: IntegerGeneratorLinearDistributionSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeEnd = useCallback(
(end: number) => {
onChange({ ...state, end });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.end')}</FormLabel>
<CompositeNumberInput value={state.end} onChange={onChangeEnd} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
</Flex>
);
}
);
IntegerGeneratorLinearDistributionSettings.displayName = 'IntegerGeneratorLinearDistributionSettings';
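Again the resolution logic lives elsewhere; a plausible reading of a linear distribution is count evenly spaced values from start to end inclusive, rounded to integers. A sketch under that assumption (the real resolver's rounding may differ):

// Illustrative resolver: count evenly spaced integers from start to end, inclusive.
const resolveLinearDistributionSketch = (state: { start: number; end: number; count: number }): number[] => {
  if (state.count === 1) {
    return [state.start];
  }
  const stepSize = (state.end - state.start) / (state.count - 1);
  return Array.from({ length: state.count }, (_, i) => Math.round(state.start + i * stepSize));
};

// resolveLinearDistributionSketch({ start: 0, end: 10, count: 5 }) -> [0, 3, 5, 8, 10]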

View File

@@ -0,0 +1,41 @@
import { Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
import type { IntegerGeneratorParseString } from 'features/nodes/types/field';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type IntegerGeneratorParseStringSettingsProps = {
state: IntegerGeneratorParseString;
onChange: (state: IntegerGeneratorParseString) => void;
};
export const IntegerGeneratorParseStringSettings = memo(
({ state, onChange }: IntegerGeneratorParseStringSettingsProps) => {
const { t } = useTranslation();
const onChangeSplitOn = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...state, splitOn: e.target.value });
},
[onChange, state]
);
const onChangeInput = useCallback(
(input: string) => {
onChange({ ...state, input });
},
[onChange, state]
);
return (
<Flex gap={2} flexDir="column">
<FormControl orientation="vertical">
<FormLabel>{t('nodes.splitOn')}</FormLabel>
<Input value={state.splitOn} onChange={onChangeSplitOn} />
</FormControl>
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
</Flex>
);
}
);
IntegerGeneratorParseStringSettings.displayName = 'IntegerGeneratorParseStringSettings';
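A parse-string generator presumably splits the raw text on splitOn and keeps the tokens that parse as integers. A sketch under those assumptions (the empty-separator fallback is a guess, not confirmed behavior):

// Illustrative resolver: split, trim, parse, and drop non-numeric tokens.
const resolveParseStringSketch = (state: { input: string; splitOn: string }): number[] =>
  state.input
    .split(state.splitOn || '\n') // assumed fallback when splitOn is empty
    .map((token) => parseInt(token.trim(), 10))
    .filter((n) => !Number.isNaN(n));

// resolveParseStringSketch({ input: '1, 2, x, 4', splitOn: ',' }) -> [1, 2, 4]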

View File

@@ -0,0 +1,78 @@
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { IntegerGeneratorUniformRandomDistribution } from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type IntegerGeneratorUniformRandomDistributionSettingsProps = {
state: IntegerGeneratorUniformRandomDistribution;
onChange: (state: IntegerGeneratorUniformRandomDistribution) => void;
};
export const IntegerGeneratorUniformRandomDistributionSettings = memo(
({ state, onChange }: IntegerGeneratorUniformRandomDistributionSettingsProps) => {
const { t } = useTranslation();
const onChangeMin = useCallback(
(min: number) => {
onChange({ ...state, min });
},
[onChange, state]
);
const onChangeMax = useCallback(
(max: number) => {
onChange({ ...state, max });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
const onToggleSeed = useCallback(() => {
onChange({ ...state, seed: isNil(state.seed) ? 0 : null });
}, [onChange, state]);
const onChangeSeed = useCallback(
(seed?: number | null) => {
onChange({ ...state, seed });
},
[onChange, state]
);
return (
<Flex gap={2} flexDir="column">
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.min')}</FormLabel>
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.max')}</FormLabel>
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel alignItems="center" justifyContent="space-between" m={0} display="flex" w="full">
{t('common.seed')}
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
</FormLabel>
<CompositeNumberInput
isDisabled={isNil(state.seed)}
// This cast is safe only because we disable the element when seed is not a number - the `...` is
// rendered in the input field in this case
value={state.seed ?? ('...' as unknown as number)}
onChange={onChangeSeed}
min={-Infinity}
max={Infinity}
/>
</FormControl>
</Flex>
</Flex>
);
}
);
IntegerGeneratorUniformRandomDistributionSettings.displayName = 'IntegerGeneratorUniformRandomDistributionSettings';
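When the seed checkbox is unchecked, seed is null and the preview elsewhere shows a placeholder instead of concrete values; with a seed set, resolution can be deterministic. A sketch using mulberry32, a small seeded PRNG chosen here purely for illustration (the app's actual random source is not shown in this diff):

// Illustrative seeded PRNG (mulberry32) so the same seed yields the same values.
const mulberry32 = (seed: number) => () => {
  seed = (seed + 0x6d2b79f5) | 0;
  let t = Math.imul(seed ^ (seed >>> 15), 1 | seed);
  t = (t + Math.imul(t ^ (t >>> 7), 61 | t)) ^ t;
  return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
};

// Illustrative resolver: `count` integers drawn uniformly from [min, max].
const resolveUniformRandomSketch = (state: { min: number; max: number; count: number; seed: number }): number[] => {
  const rand = mulberry32(state.seed);
  return Array.from({ length: state.count }, () => state.min + Math.floor(rand() * (state.max - state.min + 1)));
};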

View File

@@ -0,0 +1,237 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Button,
CompositeNumberInput,
Divider,
Flex,
FormLabel,
Grid,
GridItem,
IconButton,
} from '@invoke-ai/ui-library';
import { NUMPY_RAND_MAX } from 'app/constants';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldNumberCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;
export const NumberFieldCollectionInputComponent = memo(
(
props:
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
) => {
const { nodeId, field, fieldTemplate } = props;
const store = useAppStore();
const { t } = useTranslation();
const isInvalid = useFieldIsInvalid(nodeId, field.name);
const isIntegerField = useMemo(() => fieldTemplate.type.name === 'IntegerField', [fieldTemplate.type]);
const onRemoveNumber = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onChangeNumber = useCallback(
(index: number, value: number) => {
const newValue = field.value ? [...field.value] : [];
newValue[index] = value;
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onAddNumber = useCallback(() => {
const newValue = field.value ? [...field.value, 0] : [0];
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);
const min = useMemo(() => {
let min = -NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.minimum)) {
min = fieldTemplate.minimum;
}
if (!isNil(fieldTemplate.exclusiveMinimum)) {
min = fieldTemplate.exclusiveMinimum + 0.01;
}
return min;
}, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);
const max = useMemo(() => {
let max = NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.maximum)) {
max = fieldTemplate.maximum;
}
if (!isNil(fieldTemplate.exclusiveMaximum)) {
max = fieldTemplate.exclusiveMaximum - 0.01;
}
return max;
}, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);
const step = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.1;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
const fineStep = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.01;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
return (
<Flex
className="nodrag"
position="relative"
w="full"
h="auto"
maxH={64}
alignItems="stretch"
justifyContent="center"
p={1}
sx={sx}
data-error={isInvalid}
borderRadius="base"
flexDir="column"
gap={1}
>
<Button onClick={onAddNumber} variant="ghost">
{t('nodes.addItem')}
</Button>
{field.value && field.value.length > 0 && (
<>
<Divider />
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid gap={1} gridTemplateColumns="auto 1fr auto" alignItems="center">
{field.value.map((value, index) => (
<NumberListItemContent
key={index}
value={value}
index={index}
min={min}
max={max}
step={step}
fineStep={fineStep}
isIntegerField={isIntegerField}
onRemoveNumber={onRemoveNumber}
onChangeNumber={onChangeNumber}
/>
))}
</Grid>
</OverlayScrollbarsComponent>
</>
)}
</Flex>
);
}
);
NumberFieldCollectionInputComponent.displayName = 'NumberFieldCollectionInputComponent';
type NumberListItemContentProps = {
value: number;
index: number;
isIntegerField: boolean;
min: number;
max: number;
step: number;
fineStep: number;
onRemoveNumber: (index: number) => void;
onChangeNumber: (index: number, value: number) => void;
};
const NumberListItemContent = memo(
({
value,
index,
isIntegerField,
min,
max,
step,
fineStep,
onRemoveNumber,
onChangeNumber,
}: NumberListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveNumber(index);
}, [index, onRemoveNumber]);
const onChange = useCallback(
(v: number) => {
onChangeNumber(index, isIntegerField ? Math.floor(Number(v)) : Number(v));
},
[index, isIntegerField, onChangeNumber]
);
return (
<>
<GridItem>
<FormLabel ps={1} m={0}>
{index + 1}.
</FormLabel>
</GridItem>
<GridItem>
<CompositeNumberInput
onChange={onChange}
value={value}
min={min}
max={max}
step={step}
fineStep={fineStep}
className="nodrag"
flexGrow={1}
/>
</GridItem>
<GridItem>
<IconButton
tabIndex={-1}
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.delete')}
/>
</GridItem>
</>
);
}
);
NumberListItemContent.displayName = 'NumberListItemContent';
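The remove, change, and add callbacks above all follow the same copy-then-mutate pattern so Redux receives a fresh array. Two generic helpers capture it; these are purely illustrative, the component inlines the logic:

// Illustrative helpers for index-based immutable updates on a possibly-unset list.
const removeAt = <T,>(arr: readonly T[] | undefined, index: number): T[] => {
  const next = arr ? [...arr] : [];
  next.splice(index, 1);
  return next;
};

const setAt = <T,>(arr: readonly T[] | undefined, index: number, value: T): T[] => {
  const next = arr ? [...arr] : [];
  next[index] = value;
  return next;
};

// e.g. store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: removeAt(field.value, index) }));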

View File

@@ -0,0 +1,189 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, Divider, Flex, FormLabel, Grid, GridItem, IconButton, Input } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldStringCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type {
StringFieldCollectionInputInstance,
StringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;
export const StringFieldCollectionInputComponent = memo(
(props: FieldComponentProps<StringFieldCollectionInputInstance, StringFieldCollectionInputTemplate>) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const store = useAppStore();
const isInvalid = useFieldIsInvalid(nodeId, field.name);
const onRemoveString = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onChangeString = useCallback(
(index: number, value: string) => {
const newValue = field.value ? [...field.value] : [];
newValue[index] = value;
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onAddString = useCallback(() => {
const newValue = field.value ? [...field.value, ''] : [''];
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);
return (
<Flex
className="nodrag"
position="relative"
w="full"
h="auto"
maxH={64}
alignItems="stretch"
justifyContent="center"
p={1}
sx={sx}
data-error={isInvalid}
borderRadius="base"
flexDir="column"
gap={1}
>
<Button onClick={onAddString} variant="ghost">
{t('nodes.addItem')}
</Button>
{field.value && field.value.length > 0 && (
<>
<Divider />
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid gap={1} gridTemplateColumns="auto 1fr auto" alignItems="center">
{field.value.map((value, index) => (
<ListItemContent
key={index}
value={value}
index={index}
onRemoveString={onRemoveString}
onChangeString={onChangeString}
/>
))}
</Grid>
</OverlayScrollbarsComponent>
</>
)}
</Flex>
);
}
);
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
type StringListItemContentProps = {
value: string;
index: number;
onRemoveString: (index: number) => void;
onChangeString: (index: number, value: string) => void;
};
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveString(index);
}, [index, onRemoveString]);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChangeString(index, e.target.value);
},
[index, onChangeString]
);
return (
<Flex alignItems="center" gap={1}>
<Input size="xs" resize="none" value={value} onChange={onChange} />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.remove')}
tooltip={t('common.remove')}
/>
</Flex>
);
});
StringListItemContent.displayName = 'StringListItemContent';
type ListItemContentProps = {
value: string;
index: number;
onRemoveString: (index: number) => void;
onChangeString: (index: number, value: string) => void;
};
const ListItemContent = memo(({ value, index, onRemoveString, onChangeString }: ListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveString(index);
}, [index, onRemoveString]);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChangeString(index, e.target.value);
},
[index, onChangeString]
);
return (
<>
<GridItem>
<FormLabel ps={1} m={0}>
{index + 1}.
</FormLabel>
</GridItem>
<GridItem>
<Input size="sm" resize="none" value={value} onChange={onChange} />
</GridItem>
<GridItem>
<IconButton
tabIndex={-1}
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.delete')}
/>
</GridItem>
</>
);
});
ListItemContent.displayName = 'ListItemContent';

View File

@@ -0,0 +1,60 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
import type { StringGeneratorDynamicPromptsCombinatorial } from 'features/nodes/types/field';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
import { useDebounce } from 'use-debounce';
type StringGeneratorDynamicPromptsCombinatorialSettingsProps = {
state: StringGeneratorDynamicPromptsCombinatorial;
onChange: (state: StringGeneratorDynamicPromptsCombinatorial) => void;
};
export const StringGeneratorDynamicPromptsCombinatorialSettings = memo(
({ state, onChange }: StringGeneratorDynamicPromptsCombinatorialSettingsProps) => {
const { t } = useTranslation();
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
const onChangeInput = useCallback(
(input: string) => {
onChange({ ...state, input, values: loadingValues });
},
[onChange, state, loadingValues]
);
const onChangeMaxPrompts = useCallback(
(v: number) => {
onChange({ ...state, maxPrompts: v, values: loadingValues });
},
[onChange, state, loadingValues]
);
const arg = useMemo(() => {
const { input, maxPrompts } = state;
return { prompt: input, max_prompts: maxPrompts, combinatorial: true };
}, [state]);
const [debouncedArg] = useDebounce(arg, 300);
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
useEffect(() => {
if (isLoading) {
onChange({ ...state, values: loadingValues });
} else if (data) {
onChange({ ...state, values: data.prompts });
} else {
onChange({ ...state, values: [] });
}
}, [data, isLoading, loadingValues, onChange, state]);
return (
<Flex gap={2} flexDir="column">
<FormControl orientation="vertical">
<FormLabel>{t('dynamicPrompts.maxPrompts')}</FormLabel>
<CompositeNumberInput value={state.maxPrompts} onChange={onChangeMaxPrompts} min={1} max={1000} w="full" />
</FormControl>
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
</Flex>
);
}
);
StringGeneratorDynamicPromptsCombinatorialSettings.displayName = 'StringGeneratorDynamicPromptsCombinatorialSettings';

View File

@@ -0,0 +1,87 @@
import { Checkbox, CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
import type { StringGeneratorDynamicPromptsRandom } from 'features/nodes/types/field';
import { isNil, random } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDynamicPromptsQuery } from 'services/api/endpoints/utilities';
import { useDebounce } from 'use-debounce';
type StringGeneratorDynamicPromptsRandomSettingsProps = {
state: StringGeneratorDynamicPromptsRandom;
onChange: (state: StringGeneratorDynamicPromptsRandom) => void;
};
export const StringGeneratorDynamicPromptsRandomSettings = memo(
({ state, onChange }: StringGeneratorDynamicPromptsRandomSettingsProps) => {
const { t } = useTranslation();
const loadingValues = useMemo(() => [`<${t('nodes.generatorLoading')}>`], [t]);
const onChangeInput = useCallback(
(input: string) => {
onChange({ ...state, input, values: loadingValues });
},
[onChange, state, loadingValues]
);
const onChangeCount = useCallback(
(v: number) => {
onChange({ ...state, count: v, values: loadingValues });
},
[onChange, state, loadingValues]
);
const onToggleSeed = useCallback(() => {
onChange({ ...state, seed: isNil(state.seed) ? 0 : null, values: loadingValues });
}, [onChange, state, loadingValues]);
const onChangeSeed = useCallback(
(seed?: number | null) => {
onChange({ ...state, seed, values: loadingValues });
},
[onChange, state, loadingValues]
);
const arg = useMemo(() => {
const { input, count, seed } = state;
return { prompt: input, max_prompts: count, combinatorial: false, seed: seed ?? random() };
}, [state]);
const [debouncedArg] = useDebounce(arg, 300);
const { data, isLoading } = useDynamicPromptsQuery(debouncedArg);
useEffect(() => {
if (isLoading) {
onChange({ ...state, values: loadingValues });
} else if (data) {
onChange({ ...state, values: data.prompts });
} else {
onChange({ ...state, values: [] });
}
}, [data, isLoading, loadingValues, onChange, state]);
return (
<Flex gap={2} flexDir="column">
<Flex gap={2}>
<FormControl orientation="vertical">
<FormLabel alignItems="center" justifyContent="space-between" display="flex" w="full" pe={0.5}>
{t('common.seed')}
<Checkbox onChange={onToggleSeed} isChecked={!isNil(state.seed)} />
</FormLabel>
<CompositeNumberInput
isDisabled={isNil(state.seed)}
// This cast is safe only because we disable the element when seed is not a number - the `...` is
// rendered in the input field in this case
value={state.seed ?? ('...' as unknown as number)}
onChange={onChangeSeed}
min={-Infinity}
max={Infinity}
/>
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={1000} />
</FormControl>
</Flex>
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
</Flex>
);
}
);
StringGeneratorDynamicPromptsRandomSettings.displayName = 'StringGeneratorDynamicPromptsRandomSettings';
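Both dynamic-prompts panels debounce the request argument as a whole object before handing it to the RTK Query hook, so typing in the textarea does not fire a network call per keystroke. A condensed sketch of that step as a standalone hook, with the 300ms delay and arg shape taken from the code above:

import { useMemo } from 'react';
import { useDebounce } from 'use-debounce';

// Illustrative hook isolating the debounce step; the resulting arg would be
// passed to useDynamicPromptsQuery as in the components above.
const useDebouncedPromptArg = (input: string, count: number) => {
  const arg = useMemo(() => ({ prompt: input, max_prompts: count, combinatorial: false }), [input, count]);
  const [debouncedArg] = useDebounce(arg, 300); // settles 300ms after the last change
  return debouncedArg;
};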

View File

@@ -0,0 +1,111 @@
import { Flex, Select, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { StringGeneratorDynamicPromptsCombinatorialSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsCombinatorialSettings';
import { StringGeneratorDynamicPromptsRandomSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsRandomSettings';
import { StringGeneratorParseStringSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorParseStringSettings';
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
import { fieldStringGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
import type { StringGeneratorFieldInputInstance, StringGeneratorFieldInputTemplate } from 'features/nodes/types/field';
import {
getStringGeneratorDefaults,
resolveStringGeneratorField,
StringGeneratorDynamicPromptsCombinatorialType,
StringGeneratorDynamicPromptsRandomType,
StringGeneratorParseStringType,
} from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebounce } from 'use-debounce';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
export const StringGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<StringGeneratorFieldInputInstance, StringGeneratorFieldInputTemplate>) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onChange = useCallback(
(value: StringGeneratorFieldInputInstance['value']) => {
dispatch(
fieldStringGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const onChangeGeneratorType = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const value = getStringGeneratorDefaults(e.target.value as StringGeneratorFieldInputInstance['value']['type']);
dispatch(
fieldStringGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const [debouncedField] = useDebounce(field, 300);
const resolvedValuesAsString = useMemo(() => {
if (debouncedField.value.type === StringGeneratorDynamicPromptsRandomType && isNil(debouncedField.value.seed)) {
const { count } = debouncedField.value;
return `<${t('nodes.generatorNRandomValues', { count })}>`;
}
const resolvedValues = resolveStringGeneratorField(debouncedField);
if (resolvedValues.length === 0) {
return `<${t('nodes.generatorNoValues')}>`;
} else {
return resolvedValues.join(', ');
}
}, [debouncedField, t]);
return (
<Flex flexDir="column" gap={2}>
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
<option value={StringGeneratorParseStringType}>{t('nodes.parseString')}</option>
{/* <option value={StringGeneratorDynamicPromptsRandomType}>{t('nodes.dynamicPromptsRandom')}</option>
<option value={StringGeneratorDynamicPromptsCombinatorialType}>
{t('nodes.dynamicPromptsCombinatorial')}
</option> */}
</Select>
{field.value.type === StringGeneratorParseStringType && (
<StringGeneratorParseStringSettings state={field.value} onChange={onChange} />
)}
{field.value.type === StringGeneratorDynamicPromptsRandomType && (
<StringGeneratorDynamicPromptsRandomSettings state={field.value} onChange={onChange} />
)}
{field.value.type === StringGeneratorDynamicPromptsCombinatorialType && (
<StringGeneratorDynamicPromptsCombinatorialSettings state={field.value} onChange={onChange} />
)}
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
<Flex w="full" h="auto">
<OverlayScrollbarsComponent
className="nodrag nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
{resolvedValuesAsString}
</Text>
</OverlayScrollbarsComponent>
</Flex>
</Flex>
</Flex>
);
}
);
StringGeneratorFieldInputComponent.displayName = 'StringGeneratorFieldInputComponent';

View File

@@ -0,0 +1,41 @@
import { Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
import { GeneratorTextareaWithFileUpload } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/GeneratorTextareaWithFileUpload';
import type { StringGeneratorParseString } from 'features/nodes/types/field';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type StringGeneratorParseStringSettingsProps = {
state: StringGeneratorParseString;
onChange: (state: StringGeneratorParseString) => void;
};
export const StringGeneratorParseStringSettings = memo(
({ state, onChange }: StringGeneratorParseStringSettingsProps) => {
const { t } = useTranslation();
const onChangeSplitOn = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...state, splitOn: e.target.value });
},
[onChange, state]
);
const onChangeInput = useCallback(
(input: string) => {
onChange({ ...state, input });
},
[onChange, state]
);
return (
<Flex gap={2} flexDir="column">
<FormControl orientation="vertical">
<FormLabel>{t('nodes.splitOn')}</FormLabel>
<Input value={state.splitOn} onChange={onChangeSplitOn} />
</FormControl>
<GeneratorTextareaWithFileUpload value={state.input} onChange={onChangeInput} />
</Flex>
);
}
);
StringGeneratorParseStringSettings.displayName = 'StringGeneratorParseStringSettings';

View File

@@ -1,12 +1,14 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Editable, EditableInput, EditablePreview, Flex, useEditableControls } from '@invoke-ai/ui-library';
import type { SystemStyleObject, TextProps } from '@invoke-ai/ui-library';
import { Box, Editable, EditableInput, Flex, Text, useEditableControls } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import type { MouseEvent } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
@@ -17,6 +19,8 @@ type Props = {
const NodeTitle = ({ nodeId, title }: Props) => {
const dispatch = useAppDispatch();
const label = useNodeLabel(nodeId);
const batchGroupId = useBatchGroupId(nodeId);
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
const templateTitle = useNodeTemplateTitle(nodeId);
const { t } = useTranslation();
@@ -29,6 +33,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
[dispatch, nodeId, title, templateTitle, label, t]
);
const localTitleWithBatchGroupId = useMemo(() => {
if (!batchGroupId) {
return localTitle;
}
if (batchGroupId === 'None') {
return `${localTitle} (${t('nodes.noBatchGroup')})`;
}
return `${localTitle} (${batchGroupId})`;
}, [batchGroupId, localTitle, t]);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
@@ -50,7 +64,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
w="full"
h="full"
>
<EditablePreview fontSize="sm" p={0} w="full" noOfLines={1} />
<Preview
fontSize="sm"
p={0}
w="full"
noOfLines={1}
color={batchGroupColorToken}
fontWeight={batchGroupId ? 'semibold' : undefined}
>
{localTitleWithBatchGroupId}
</Preview>
<EditableInput className="nodrag" fontSize="sm" sx={editableInputStyles} />
<EditableControls />
</Editable>
@@ -60,6 +83,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
export default memo(NodeTitle);
const Preview = (props: TextProps) => {
const { isEditing } = useEditableControls();
if (isEditing) {
return null;
}
return <Text {...props} />;
};
function EditableControls() {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback(

View File

@@ -5,7 +5,7 @@ import NodeOpacitySlider from './NodeOpacitySlider';
import ViewportControls from './ViewportControls';
const BottomLeftPanel = () => (
<Flex gap={2} position="absolute" bottom={0} insetInlineStart={0}>
<Flex gap={2} position="absolute" bottom={2} insetInlineStart={2}>
<ViewportControls />
<NodeOpacitySlider />
</Flex>

View File

@@ -20,7 +20,7 @@ const MinimapPanel = () => {
const shouldShowMinimapPanel = useAppSelector(selectShouldShowMinimapPanel);
return (
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
<Flex gap={2} position="absolute" bottom={2} insetInlineEnd={2}>
{shouldShowMinimapPanel && (
<ChakraMiniMap
pannable

View File

@@ -12,7 +12,7 @@ import { memo } from 'react';
const TopCenterPanel = () => {
const name = useAppSelector(selectWorkflowName);
return (
<Flex gap={2} top={0} left={0} right={0} position="absolute" alignItems="flex-start" pointerEvents="none">
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
<Flex gap="2">
<AddNodeButton />
<UpdateNodesButton />

View File

@@ -0,0 +1,22 @@
import { useMemo } from 'react';
export const useBatchGroupColorToken = (batchGroupId?: string) => {
const batchGroupColorToken = useMemo(() => {
switch (batchGroupId) {
case 'Group 1':
return 'invokeGreen.300';
case 'Group 2':
return 'invokeBlue.300';
case 'Group 3':
return 'invokePurple.200';
case 'Group 4':
return 'invokeRed.300';
case 'Group 5':
return 'invokeYellow.300';
default:
return undefined;
}
}, [batchGroupId]);
return batchGroupColorToken;
};

View File

@@ -0,0 +1,19 @@
import { useNode } from 'features/nodes/hooks/useNode';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
export const useBatchGroupId = (nodeId: string) => {
const node = useNode(nodeId);
const batchGroupId = useMemo(() => {
if (!isInvocationNode(node)) {
return;
}
if (!isBatchNode(node)) {
return;
}
return node.data.inputs['batch_group_id']?.value as string;
}, [node]);
return batchGroupId;
};
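These two hooks compose naturally: one reads the batch node's group id from its input field, the other maps the id to a theme color token. A sketch of a hypothetical consumer (not a component in this change):

import { Text } from '@invoke-ai/ui-library';
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';

// Illustrative badge: render nothing for non-batch nodes, else show the group
// id in its assigned color.
const BatchGroupBadge = ({ nodeId }: { nodeId: string }) => {
  const batchGroupId = useBatchGroupId(nodeId); // undefined for non-batch nodes
  const color = useBatchGroupColorToken(batchGroupId);
  if (!batchGroupId) {
    return null;
  }
  return <Text color={color}>{batchGroupId}</Text>;
};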

View File

@@ -3,7 +3,21 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import {
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import {
validateImageFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
} from 'features/nodes/types/fieldValidators';
import { useMemo } from 'react';
export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
@@ -35,13 +49,27 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
}
// Else special handling for individual field types
if (isImageFieldCollectionInputInstance(field) && isImageFieldCollectionInputTemplate(template)) {
// Image collections may have min or max item counts
if (template.minItems !== undefined && field.value.length < template.minItems) {
if (validateImageFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}
if (template.maxItems !== undefined && field.value.length > template.maxItems) {
if (isStringFieldCollectionInputInstance(field) && isStringFieldCollectionInputTemplate(template)) {
if (validateStringFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}
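The hook now treats a field as invalid whenever its validator returns a non-empty list of problems. The validators themselves are not shown in this diff; a plausible sketch of the string-collection one, assuming it still enforces the minItems/maxItems bounds that the old inline checks covered:

// Illustrative validator: return human-readable reasons; the hook above only
// cares whether the list is non-empty.
type CollectionTemplateSketch = { minItems?: number; maxItems?: number };

const validateStringCollectionSketch = (value: string[] | undefined, template: CollectionTemplateSketch): string[] => {
  const reasons: string[] = [];
  const count = value?.length ?? 0;
  if (template.minItems !== undefined && count < template.minItems) {
    reasons.push(`expected at least ${template.minItems} items, got ${count}`);
  }
  if (template.maxItems !== undefined && count > template.maxItems) {
    reasons.push(`expected at most ${template.maxItems} items, got ${count}`);
  }
  return reasons;
};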

View File

@@ -16,10 +16,13 @@ import type {
EnumFieldValue,
FieldValue,
FloatFieldValue,
FloatGeneratorFieldValue,
FluxVAEModelFieldValue,
ImageFieldCollectionValue,
ImageFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IntegerGeneratorFieldValue,
IPAdapterModelFieldValue,
LoRAModelFieldValue,
MainModelFieldValue,
@@ -28,7 +31,9 @@ import type {
SDXLRefinerModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldCollectionValue,
StringFieldValue,
StringGeneratorFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
@@ -43,11 +48,15 @@ import {
zControlLoRAModelFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
zFloatFieldCollectionValue,
zFloatFieldValue,
zFloatGeneratorFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
zImageFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIntegerGeneratorFieldValue,
zIPAdapterModelFieldValue,
zLoRAModelFieldValue,
zMainModelFieldValue,
@@ -56,7 +65,9 @@ import {
zSDXLRefinerModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldCollectionValue,
zStringFieldValue,
zStringGeneratorFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
@@ -311,9 +322,15 @@ export const nodesSlice = createSlice({
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
fieldValueReducer(state, action, zStringFieldValue);
},
fieldStringCollectionValueChanged: (state, action: FieldValueAction<StringFieldCollectionValue>) => {
fieldValueReducer(state, action, zStringFieldCollectionValue);
},
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
},
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
},
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
fieldValueReducer(state, action, zBooleanFieldValue);
},
@@ -383,6 +400,15 @@ export const nodesSlice = createSlice({
fieldSchedulerValueChanged: (state, action: FieldValueAction<SchedulerFieldValue>) => {
fieldValueReducer(state, action, zSchedulerFieldValue);
},
fieldFloatGeneratorValueChanged: (state, action: FieldValueAction<FloatGeneratorFieldValue>) => {
fieldValueReducer(state, action, zFloatGeneratorFieldValue);
},
fieldIntegerGeneratorValueChanged: (state, action: FieldValueAction<IntegerGeneratorFieldValue>) => {
fieldValueReducer(state, action, zIntegerGeneratorFieldValue);
},
fieldStringGeneratorValueChanged: (state, action: FieldValueAction<StringGeneratorFieldValue>) => {
fieldValueReducer(state, action, zStringGeneratorFieldValue);
},
notesNodeValueChanged: (state, action: PayloadAction<{ nodeId: string; value: string }>) => {
const { nodeId, value } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
@@ -435,9 +461,11 @@ export const {
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
@@ -445,6 +473,9 @@ export const {
fieldCLIPGEmbedValueChanged,
fieldControlLoRAModelValueChanged,
fieldFluxVAEModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,
fieldStringGeneratorValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
@@ -546,9 +577,11 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldLoRAModelValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,

View File

@@ -8,17 +8,21 @@ describe(areTypesEqual.name, () => {
const sourceType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'Foo',
cardinality: 'SINGLE',
batch: false,
},
};
const targetType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'Bar',
cardinality: 'SINGLE',
batch: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
@@ -28,17 +32,21 @@ describe(areTypesEqual.name, () => {
const sourceType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'Foo',
cardinality: 'SINGLE',
batch: false,
},
};
const targetType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
@@ -48,17 +56,21 @@ describe(areTypesEqual.name, () => {
const sourceType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
};
const targetType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'Bar',
cardinality: 'SINGLE',
batch: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
@@ -68,17 +80,21 @@ describe(areTypesEqual.name, () => {
const sourceType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
};
const targetType: FieldType = {
name: 'LoRAModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
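These fixture updates all stem from one schema change: FieldType gained a required batch flag. Read off the fixtures in these test files (SINGLE_OR_COLLECTION appears in the connection-validation tests further down), the shape is roughly the following; this is an inference from the test data, and the real type is derived from a zod schema in features/nodes/types/field:

// Illustrative FieldType shape implied by the test fixtures.
type FieldTypeSketch = {
  name: string;
  cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION';
  batch: boolean;
  originalType?: FieldTypeSketch;
};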

View File

@@ -11,7 +11,7 @@ describe(getCollectItemType.name, () => {
const n2 = buildNode(collect);
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const result = getCollectItemType(templates, [n1, n2], [e1], n2.id);
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE' });
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE', batch: false });
});
it('should return null if the collect node does not have any connections', () => {
const n1 = buildNode(collect);

View File

@@ -34,6 +34,7 @@ export const add: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
default: 0,
},
@@ -48,6 +49,7 @@ export const add: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
default: 0,
},
@@ -61,6 +63,7 @@ export const add: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},
@@ -89,6 +92,7 @@ export const sub: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
default: 0,
},
@@ -103,6 +107,7 @@ export const sub: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
default: 0,
},
@@ -116,6 +121,7 @@ export const sub: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},
@@ -145,6 +151,7 @@ export const collect: InvocationTemplate = {
type: {
name: 'CollectionItemField',
cardinality: 'SINGLE',
batch: false,
},
},
},
@@ -157,6 +164,7 @@ export const collect: InvocationTemplate = {
type: {
name: 'CollectionField',
cardinality: 'COLLECTION',
batch: false,
},
ui_hidden: false,
ui_type: 'CollectionField',
@@ -187,10 +195,11 @@ const scheduler: InvocationTemplate = {
type: {
name: 'SchedulerField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'EnumField',
cardinality: 'SINGLE',
batch: false,
},
},
default: 'euler',
@@ -205,10 +214,12 @@ const scheduler: InvocationTemplate = {
type: {
name: 'SchedulerField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'EnumField',
cardinality: 'SINGLE',
batch: false,
},
},
ui_hidden: false,
@@ -240,10 +251,12 @@ export const main_model_loader: InvocationTemplate = {
type: {
name: 'MainModelField',
cardinality: 'SINGLE',
batch: false,
originalType: {
name: 'ModelIdentifierField',
cardinality: 'SINGLE',
batch: false,
},
},
},
@@ -257,6 +270,7 @@ export const main_model_loader: InvocationTemplate = {
type: {
name: 'VAEField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},
@@ -268,6 +282,7 @@ export const main_model_loader: InvocationTemplate = {
type: {
name: 'CLIPField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},
@@ -279,6 +294,7 @@ export const main_model_loader: InvocationTemplate = {
type: {
name: 'UNetField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},
@@ -307,6 +323,7 @@ export const img_resize: InvocationTemplate = {
type: {
name: 'BoardField',
cardinality: 'SINGLE',
batch: false,
},
},
metadata: {
@@ -320,6 +337,7 @@ export const img_resize: InvocationTemplate = {
type: {
name: 'MetadataField',
cardinality: 'SINGLE',
batch: false,
},
},
image: {
@@ -333,6 +351,7 @@ export const img_resize: InvocationTemplate = {
type: {
name: 'ImageField',
cardinality: 'SINGLE',
batch: false,
},
},
width: {
@@ -346,6 +365,7 @@ export const img_resize: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
default: 512,
exclusiveMinimum: 0,
@@ -361,6 +381,7 @@ export const img_resize: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
default: 512,
exclusiveMinimum: 0,
@@ -376,6 +397,7 @@ export const img_resize: InvocationTemplate = {
type: {
name: 'EnumField',
cardinality: 'SINGLE',
batch: false,
},
options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'],
default: 'bicubic',
@@ -390,6 +412,7 @@ export const img_resize: InvocationTemplate = {
type: {
name: 'ImageField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},
@@ -401,6 +424,7 @@ export const img_resize: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},
@@ -412,6 +436,7 @@ export const img_resize: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},
@@ -441,6 +466,7 @@ const iterate: InvocationTemplate = {
type: {
name: 'CollectionField',
cardinality: 'COLLECTION',
batch: false,
},
},
},
@@ -453,6 +479,7 @@ const iterate: InvocationTemplate = {
type: {
name: 'CollectionItemField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
ui_type: 'CollectionItemField',
@@ -465,6 +492,7 @@ const iterate: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},
@@ -476,6 +504,7 @@ const iterate: InvocationTemplate = {
type: {
name: 'IntegerField',
cardinality: 'SINGLE',
batch: false,
},
ui_hidden: false,
},

View File

@@ -6,50 +6,57 @@ describe(validateConnectionTypes.name, () => {
describe('generic cases', () => {
it('should accept SINGLE to SINGLE of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'FooField', cardinality: 'SINGLE' }
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
{ name: 'FooField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(true);
});
it('should accept COLLECTION to COLLECTION of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'COLLECTION' },
{ name: 'FooField', cardinality: 'COLLECTION' }
{ name: 'FooField', cardinality: 'COLLECTION', batch: false },
{ name: 'FooField', cardinality: 'COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it('should accept SINGLE to SINGLE_OR_COLLECTION of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it('should accept COLLECTION to SINGLE_OR_COLLECTION of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'COLLECTION' },
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }
{ name: 'FooField', cardinality: 'COLLECTION', batch: false },
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it('should reject COLLECTION to SINGLE of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'COLLECTION' },
{ name: 'FooField', cardinality: 'SINGLE' }
{ name: 'FooField', cardinality: 'COLLECTION', batch: false },
{ name: 'FooField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(false);
});
it('should reject SINGLE_OR_COLLECTION to SINGLE of same type', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' },
{ name: 'FooField', cardinality: 'SINGLE' }
{ name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
{ name: 'FooField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(false);
});
it('should reject types with mismatch batch fields', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
{ name: 'FooField', cardinality: 'SINGLE', batch: true }
);
expect(r).toBe(false);
});
it('should reject mismatched types', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'BarField', cardinality: 'SINGLE' }
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
{ name: 'BarField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(false);
});
@@ -58,16 +65,16 @@ describe(validateConnectionTypes.name, () => {
describe('special cases', () => {
it('should reject a COLLECTION input to a COLLECTION input', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', cardinality: 'COLLECTION' },
{ name: 'CollectionField', cardinality: 'COLLECTION' }
{ name: 'CollectionField', cardinality: 'COLLECTION', batch: false },
{ name: 'CollectionField', cardinality: 'COLLECTION', batch: false }
);
expect(r).toBe(false);
});
it('should accept equal types', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE' }
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false },
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(true);
});
@@ -75,36 +82,36 @@ describe(validateConnectionTypes.name, () => {
describe('CollectionItemField', () => {
it('should accept CollectionItemField to any SINGLE target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE' }
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false },
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(true);
});
it('should accept CollectionItemField to any SINGLE_OR_COLLECTION target', () => {
const r = validateConnectionTypes(
{ name: 'CollectionItemField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it('should accept any SINGLE to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE' },
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false },
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(true);
});
it('should reject any COLLECTION to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'COLLECTION' },
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
{ name: 'IntegerField', cardinality: 'COLLECTION', batch: false },
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(false);
});
it('should reject any SINGLE_OR_COLLECTION to CollectionItemField', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
{ name: 'CollectionItemField', cardinality: 'SINGLE' }
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
{ name: 'CollectionItemField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(false);
});
@@ -113,22 +120,22 @@ describe(validateConnectionTypes.name, () => {
describe('SINGLE_OR_COLLECTION', () => {
it('should accept any SINGLE of same type to SINGLE_OR_COLLECTION', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
{ name: 'IntegerField', cardinality: 'SINGLE', batch: false },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it('should accept any COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'COLLECTION' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
{ name: 'IntegerField', cardinality: 'COLLECTION', batch: false },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it('should accept any SINGLE_OR_COLLECTION of same type to SINGLE_OR_COLLECTION', () => {
const r = validateConnectionTypes(
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});
@@ -137,15 +144,15 @@ describe(validateConnectionTypes.name, () => {
describe('CollectionField', () => {
it('should accept any CollectionField to any COLLECTION type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'COLLECTION' }
{ name: 'CollectionField', cardinality: 'SINGLE', batch: false },
{ name: 'IntegerField', cardinality: 'COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it('should accept any CollectionField to any SINGLE_OR_COLLECTION type', () => {
const r = validateConnectionTypes(
{ name: 'CollectionField', cardinality: 'SINGLE' },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }
{ name: 'CollectionField', cardinality: 'SINGLE', batch: false },
{ name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});
@@ -159,27 +166,30 @@ describe(validateConnectionTypes.name, () => {
{ t1: 'FloatField', t2: 'StringField' },
];
it.each(typePairs)('should accept SINGLE $t1 to SINGLE $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes({ name: t1, cardinality: 'SINGLE' }, { name: t2, cardinality: 'SINGLE' });
const r = validateConnectionTypes(
{ name: t1, cardinality: 'SINGLE', batch: false },
{ name: t2, cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept SINGLE $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, cardinality: 'SINGLE' },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
{ name: t1, cardinality: 'SINGLE', batch: false },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept COLLECTION $t1 to COLLECTION $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, cardinality: 'COLLECTION' },
{ name: t2, cardinality: 'COLLECTION' }
{ name: t1, cardinality: 'COLLECTION', batch: false },
{ name: t2, cardinality: 'COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it.each(typePairs)('should accept COLLECTION $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, cardinality: 'COLLECTION' },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
{ name: t1, cardinality: 'COLLECTION', batch: false },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});
@@ -187,8 +197,8 @@ describe(validateConnectionTypes.name, () => {
'should accept SINGLE_OR_COLLECTION $t1 to SINGLE_OR_COLLECTION $t2',
({ t1, t2 }: TypePair) => {
const r = validateConnectionTypes(
{ name: t1, cardinality: 'SINGLE_OR_COLLECTION' },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION' }
{ name: t1, cardinality: 'SINGLE_OR_COLLECTION', batch: false },
{ name: t2, cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
}
@@ -198,22 +208,22 @@ describe(validateConnectionTypes.name, () => {
describe('AnyField', () => {
it('should accept any SINGLE type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'AnyField', cardinality: 'SINGLE' }
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
{ name: 'AnyField', cardinality: 'SINGLE', batch: false }
);
expect(r).toBe(true);
});
it('should accept any COLLECTION type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'AnyField', cardinality: 'COLLECTION' }
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
{ name: 'AnyField', cardinality: 'COLLECTION', batch: false }
);
expect(r).toBe(true);
});
it('should accept any SINGLE_OR_COLLECTION type to AnyField', () => {
const r = validateConnectionTypes(
{ name: 'FooField', cardinality: 'SINGLE' },
{ name: 'AnyField', cardinality: 'SINGLE_OR_COLLECTION' }
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
{ name: 'AnyField', cardinality: 'SINGLE_OR_COLLECTION', batch: false }
);
expect(r).toBe(true);
});

View File

@@ -19,6 +19,11 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
return true;
}
// Batch and non-batch fields are incompatible.
if (sourceType.batch !== targetType.batch) {
return false;
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-COLLECTION (e.g. SINGLE or SINGLE_OR_COLLECTION)
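
The guard above is the crux of the change: batch-ness is compared before any of the name or cardinality logic runs, so a mismatch short-circuits to false. A minimal sketch of the gate in isolation, assuming the FieldType shape sketched earlier:

// Hedged sketch, not the full validateConnectionTypes implementation.
const passesBatchGate = (source: FieldType, target: FieldType): boolean =>
source.batch === target.batch;

passesBatchGate(
{ name: 'FooField', cardinality: 'SINGLE', batch: false },
{ name: 'FooField', cardinality: 'SINGLE', batch: true }
); // false - mirrors the 'mismatch batch fields' test above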

File diff suppressed because it is too large.

View File

@@ -0,0 +1,110 @@
import type {
FloatFieldCollectionInputTemplate,
FloatFieldCollectionValue,
ImageFieldCollectionInputTemplate,
ImageFieldCollectionValue,
IntegerFieldCollectionInputTemplate,
IntegerFieldCollectionValue,
StringFieldCollectionInputTemplate,
StringFieldCollectionValue,
} from 'features/nodes/types/field';
import { t } from 'i18next';
export const validateImageFieldCollectionValue = (
value: NonNullable<ImageFieldCollectionValue>,
template: ImageFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems } = template;
const count = value.length;
// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}
if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}
if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}
return reasons;
};
export const validateStringFieldCollectionValue = (
value: NonNullable<StringFieldCollectionValue>,
template: StringFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minLength, maxLength } = template;
const count = value.length;
// String collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}
if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}
if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}
for (const str of value) {
if (maxLength !== undefined && str.length > maxLength) {
reasons.push(t('parameters.invoke.collectionStringTooLong', { value: str, maxLength }));
}
if (minLength !== undefined && str.length < minLength) {
reasons.push(t('parameters.invoke.collectionStringTooShort', { value: str, minLength }));
}
}
return reasons;
};
export const validateNumberFieldCollectionValue = (
value: NonNullable<IntegerFieldCollectionValue> | NonNullable<FloatFieldCollectionValue>,
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum, multipleOf } = template;
const count = value.length;
// Number collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}
if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}
if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}
for (const num of value) {
if (maximum !== undefined && num > maximum) {
reasons.push(t('parameters.invoke.collectionNumberGTMax', { value: num, maximum }));
}
if (minimum !== undefined && num < minimum) {
reasons.push(t('parameters.invoke.collectionNumberLTMin', { value: num, minimum }));
}
if (exclusiveMaximum !== undefined && num >= exclusiveMaximum) {
reasons.push(t('parameters.invoke.collectionNumberGTExclusiveMax', { value: num, exclusiveMaximum }));
}
if (exclusiveMinimum !== undefined && num <= exclusiveMinimum) {
reasons.push(t('parameters.invoke.collectionNumberLTExclusiveMin', { value: num, exclusiveMinimum }));
}
if (multipleOf !== undefined && num % multipleOf !== 0) {
reasons.push(t('parameters.invoke.collectionNumberNotMultipleOf', { value: num, multipleOf }));
}
}
return reasons;
};
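
For orientation, a hedged usage sketch of the number validator above. The template object is a hypothetical stand-in for a real IntegerFieldCollectionInputTemplate, and each returned string is an i18next lookup:

// Hypothetical template: 2-4 integers, each between 1 and 10.
const template = {
minItems: 2,
maxItems: 4,
minimum: 1,
maximum: 10,
} as IntegerFieldCollectionInputTemplate;

validateNumberFieldCollectionValue([0, 5, 12], template); // two reasons: 0 < minimum, 12 > maximum
validateNumberFieldCollectionValue([2, 3], template); // [] - valid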

View File

@@ -91,3 +91,30 @@ const zInvocationNodeEdgeExtra = z.object({
type InvocationNodeEdgeExtra = z.infer<typeof zInvocationNodeEdgeExtra>;
export type InvocationNodeEdge = Edge<InvocationNodeEdgeExtra>;
// #endregion
export const isBatchNode = (node: InvocationNode) => {
switch (node.data.type) {
case 'image_batch':
case 'string_batch':
case 'integer_batch':
case 'float_batch':
return true;
default:
return false;
}
};
const isGeneratorNode = (node: InvocationNode) => {
switch (node.data.type) {
case 'float_generator':
case 'integer_generator':
case 'string_generator':
return true;
default:
return false;
}
};
export const isExecutableNode = (node: InvocationNode) => {
return !isBatchNode(node) && !isGeneratorNode(node);
};
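
Together these predicates partition invocation nodes into three buckets - batch, generator, and executable - which the graph builder below relies on. A hedged filtering sketch:

declare const nodes: InvocationNode[]; // assumed input for the sketch
const batchNodes = nodes.filter(isBatchNode); // expanded into queue batches, not serialized
const executableNodes = nodes.filter(isExecutableNode); // serialized into the invocation graph
// Generator nodes are neither: resolveBatchValue folds them into their batch node's collection.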

View File

@@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
@@ -14,7 +14,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const { nodes, edges } = nodesState;
// Exclude all batch nodes - we will handle these in the batch setup in a different function
const filteredNodes = nodes.filter(isInvocationNode).filter((node) => node.data.type !== 'image_batch');
const filteredNodes = nodes.filter(isInvocationNode).filter(isExecutableNode);
// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {

View File

@@ -29,6 +29,9 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
ControlLoRAModelField: undefined,
FloatGeneratorField: undefined,
IntegerGeneratorField: undefined,
StringGeneratorField: undefined,
};
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

View File

@@ -11,12 +11,16 @@ import type {
EnumFieldInputTemplate,
FieldInputTemplate,
FieldType,
FloatFieldCollectionInputTemplate,
FloatFieldInputTemplate,
FloatGeneratorFieldInputTemplate,
FluxMainModelFieldInputTemplate,
FluxVAEModelFieldInputTemplate,
ImageFieldCollectionInputTemplate,
ImageFieldInputTemplate,
IntegerFieldCollectionInputTemplate,
IntegerFieldInputTemplate,
IntegerGeneratorFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
LoRAModelFieldInputTemplate,
MainModelFieldInputTemplate,
@@ -28,12 +32,23 @@ import type {
SpandrelImageToImageModelFieldInputTemplate,
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldCollectionInputTemplate,
StringFieldInputTemplate,
StringGeneratorFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { isImageCollectionFieldType, isStatefulFieldType } from 'features/nodes/types/field';
import {
getFloatGeneratorArithmeticSequenceDefaults,
getIntegerGeneratorArithmeticSequenceDefaults,
getStringGeneratorParseStringDefaults,
isFloatCollectionFieldType,
isImageCollectionFieldType,
isIntegerCollectionFieldType,
isStatefulFieldType,
isStringCollectionFieldType,
} from 'features/nodes/types/field';
import type { InvocationFieldSchema } from 'features/nodes/types/openapi';
import { isSchemaObject } from 'features/nodes/types/openapi';
import { t } from 'i18next';
@@ -77,6 +92,48 @@ const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<IntegerFieldInpu
return template;
};
const buildIntegerFieldCollectionInputTemplate: FieldInputTemplateBuilder<IntegerFieldCollectionInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: IntegerFieldCollectionInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
};
if (schemaObject.minItems !== undefined) {
template.minItems = schemaObject.minItems;
}
if (schemaObject.maxItems !== undefined) {
template.maxItems = schemaObject.maxItems;
}
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -111,6 +168,48 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTem
return template;
};
const buildFloatFieldCollectionInputTemplate: FieldInputTemplateBuilder<FloatFieldCollectionInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FloatFieldCollectionInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
};
if (schemaObject.minItems !== undefined) {
template.minItems = schemaObject.minItems;
}
if (schemaObject.maxItems !== undefined) {
template.maxItems = schemaObject.maxItems;
}
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -133,6 +232,36 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputT
return template;
};
const buildStringFieldCollectionInputTemplate: FieldInputTemplateBuilder<StringFieldCollectionInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: StringFieldCollectionInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
};
if (schemaObject.minLength !== undefined) {
template.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
template.maxLength = schemaObject.maxLength;
}
if (schemaObject.minItems !== undefined) {
template.minItems = schemaObject.minItems;
}
if (schemaObject.maxItems !== undefined) {
template.maxItems = schemaObject.maxItems;
}
return template;
};
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -514,6 +643,48 @@ const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder<SchedulerField
return template;
};
const buildFloatGeneratorFieldInputTemplate: FieldInputTemplateBuilder<FloatGeneratorFieldInputTemplate> = ({
// schemaObject,
baseField,
fieldType,
}) => {
const template: FloatGeneratorFieldInputTemplate = {
...baseField,
type: fieldType,
default: getFloatGeneratorArithmeticSequenceDefaults(),
};
return template;
};
const buildIntegerGeneratorFieldInputTemplate: FieldInputTemplateBuilder<IntegerGeneratorFieldInputTemplate> = ({
// schemaObject,
baseField,
fieldType,
}) => {
const template: IntegerGeneratorFieldInputTemplate = {
...baseField,
type: fieldType,
default: getIntegerGeneratorArithmeticSequenceDefaults(),
};
return template;
};
const buildStringGeneratorFieldInputTemplate: FieldInputTemplateBuilder<StringGeneratorFieldInputTemplate> = ({
// schemaObject,
baseField,
fieldType,
}) => {
const template: StringGeneratorFieldInputTemplate = {
...baseField,
type: fieldType,
default: getStringGeneratorParseStringDefaults(),
};
return template;
};
export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
BoardField: buildBoardFieldInputTemplate,
BooleanField: buildBooleanFieldInputTemplate,
@@ -542,6 +713,9 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
ControlLoRAModelField: buildControlLoRAModelFieldInputTemplate,
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
StringGeneratorField: buildStringGeneratorFieldInputTemplate,
} as const;
export const buildFieldInputTemplate = (
@@ -569,12 +743,29 @@ export const buildFieldInputTemplate = (
if (isStatefulFieldType(fieldType)) {
if (isImageCollectionFieldType(fieldType)) {
fieldType;
return buildImageFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else if (isStringCollectionFieldType(fieldType)) {
return buildStringFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else if (isIntegerCollectionFieldType(fieldType)) {
return buildIntegerFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else if (isFloatCollectionFieldType(fieldType)) {
return buildFloatFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else {
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
const template = builder({

View File

@@ -19,42 +19,42 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
{
name: 'SINGLE IntegerField',
schema: { type: 'integer' },
expected: { name: 'IntegerField', cardinality: 'SINGLE' },
expected: { name: 'IntegerField', cardinality: 'SINGLE', batch: false },
},
{
name: 'SINGLE FloatField',
schema: { type: 'number' },
expected: { name: 'FloatField', cardinality: 'SINGLE' },
expected: { name: 'FloatField', cardinality: 'SINGLE', batch: false },
},
{
name: 'SINGLE StringField',
schema: { type: 'string' },
expected: { name: 'StringField', cardinality: 'SINGLE' },
expected: { name: 'StringField', cardinality: 'SINGLE', batch: false },
},
{
name: 'SINGLE BooleanField',
schema: { type: 'boolean' },
expected: { name: 'BooleanField', cardinality: 'SINGLE' },
expected: { name: 'BooleanField', cardinality: 'SINGLE', batch: false },
},
{
name: 'COLLECTION IntegerField',
schema: { items: { type: 'integer' }, type: 'array' },
expected: { name: 'IntegerField', cardinality: 'COLLECTION' },
expected: { name: 'IntegerField', cardinality: 'COLLECTION', batch: false },
},
{
name: 'COLLECTION FloatField',
schema: { items: { type: 'number' }, type: 'array' },
expected: { name: 'FloatField', cardinality: 'COLLECTION' },
expected: { name: 'FloatField', cardinality: 'COLLECTION', batch: false },
},
{
name: 'COLLECTION StringField',
schema: { items: { type: 'string' }, type: 'array' },
expected: { name: 'StringField', cardinality: 'COLLECTION' },
expected: { name: 'StringField', cardinality: 'COLLECTION', batch: false },
},
{
name: 'COLLECTION BooleanField',
schema: { items: { type: 'boolean' }, type: 'array' },
expected: { name: 'BooleanField', cardinality: 'COLLECTION' },
expected: { name: 'BooleanField', cardinality: 'COLLECTION', batch: false },
},
{
name: 'SINGLE_OR_COLLECTION IntegerField',
@@ -71,7 +71,7 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' },
expected: { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
},
{
name: 'SINGLE_OR_COLLECTION FloatField',
@@ -88,7 +88,7 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'FloatField', cardinality: 'SINGLE_OR_COLLECTION' },
expected: { name: 'FloatField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
},
{
name: 'SINGLE_OR_COLLECTION StringField',
@@ -105,7 +105,7 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'StringField', cardinality: 'SINGLE_OR_COLLECTION' },
expected: { name: 'StringField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
},
{
name: 'SINGLE_OR_COLLECTION BooleanField',
@@ -122,7 +122,7 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'BooleanField', cardinality: 'SINGLE_OR_COLLECTION' },
expected: { name: 'BooleanField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
},
];
@@ -136,7 +136,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', cardinality: 'SINGLE' },
expected: { name: 'ConditioningField', cardinality: 'SINGLE', batch: false },
},
{
name: 'Nullable SINGLE ConditioningField',
@@ -150,7 +150,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', cardinality: 'SINGLE' },
expected: { name: 'ConditioningField', cardinality: 'SINGLE', batch: false },
},
{
name: 'COLLECTION ConditioningField',
@@ -164,7 +164,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', cardinality: 'COLLECTION' },
expected: { name: 'ConditioningField', cardinality: 'COLLECTION', batch: false },
},
{
name: 'Nullable Collection ConditioningField',
@@ -181,7 +181,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', cardinality: 'COLLECTION' },
expected: { name: 'ConditioningField', cardinality: 'COLLECTION', batch: false },
},
{
name: 'SINGLE_OR_COLLECTION ConditioningField',
@@ -198,7 +198,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' },
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
},
{
name: 'Nullable SINGLE_OR_COLLECTION ConditioningField',
@@ -218,7 +218,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [
},
],
},
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' },
expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION', batch: false },
},
];
@@ -229,14 +229,14 @@ const specialCases: ParseFieldTypeTestCase[] = [
type: 'string',
enum: ['large', 'base', 'small'],
},
expected: { name: 'EnumField', cardinality: 'SINGLE' },
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
},
{
name: 'String EnumField with one value',
schema: {
const: 'Some Value',
},
expected: { name: 'EnumField', cardinality: 'SINGLE' },
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
},
{
name: 'Explicit ui_type (SchedulerField)',
@@ -245,7 +245,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
enum: ['ddim', 'ddpm', 'deis'],
ui_type: 'SchedulerField',
},
expected: { name: 'EnumField', cardinality: 'SINGLE' },
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
},
{
name: 'Explicit ui_type (AnyField)',
@@ -254,7 +254,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
enum: ['ddim', 'ddpm', 'deis'],
ui_type: 'AnyField',
},
expected: { name: 'EnumField', cardinality: 'SINGLE' },
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
},
{
name: 'Explicit ui_type (CollectionField)',
@@ -263,7 +263,7 @@ const specialCases: ParseFieldTypeTestCase[] = [
enum: ['ddim', 'ddpm', 'deis'],
ui_type: 'CollectionField',
},
expected: { name: 'EnumField', cardinality: 'SINGLE' },
expected: { name: 'EnumField', cardinality: 'SINGLE', batch: false },
},
];

View File

@@ -49,6 +49,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name: 'EnumField',
cardinality: 'SINGLE',
batch: false,
};
}
if (!schemaObject.type) {
@@ -65,6 +66,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name,
cardinality: 'SINGLE',
batch: false,
};
}
} else if (schemaObject.anyOf) {
@@ -88,6 +90,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name,
cardinality: 'SINGLE',
batch: false,
};
} else if (isSchemaObject(filteredAnyOf[0])) {
return parseFieldType(filteredAnyOf[0]);
@@ -141,6 +144,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name: OPENAPI_TO_FIELD_TYPE_MAP[firstType] ?? firstType,
cardinality: 'SINGLE_OR_COLLECTION',
batch: false,
};
}
@@ -155,6 +159,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name: 'EnumField',
cardinality: 'SINGLE',
batch: false,
};
} else if (schemaObject.type) {
if (schemaObject.type === 'array') {
@@ -181,6 +186,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name,
cardinality: 'COLLECTION',
batch: false,
};
}
@@ -192,6 +198,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name,
cardinality: 'COLLECTION',
batch: false,
};
} else if (!isArray(schemaObject.type)) {
// This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean'
@@ -207,6 +214,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name,
cardinality: 'SINGLE',
batch: false,
};
}
}
@@ -218,6 +226,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType
return {
name,
cardinality: 'SINGLE',
batch: false,
};
}
throw new FieldParseError(t('nodes.unableToParseFieldType'));
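
With these changes every return path carries batch: false; the flag is only flipped later, in parseSchema, for the special batch and generator fields. A hedged example of the parser's output for two simple schemas, mirroring the test cases above:

parseFieldType({ type: 'integer' });
// -> { name: 'IntegerField', cardinality: 'SINGLE', batch: false }
parseFieldType({ items: { type: 'string' }, type: 'array' });
// -> { name: 'StringField', cardinality: 'COLLECTION', batch: false }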

View File

@@ -64,6 +64,32 @@ const isAllowedOutputField = (nodeType: string, fieldName: string) => {
return true;
};
const isBatchInputField = (nodeType: string, fieldName: string) => {
if (nodeType === 'float_batch' && fieldName === 'floats') {
return true;
}
if (nodeType === 'integer_batch' && fieldName === 'integers') {
return true;
}
if (nodeType === 'string_batch' && fieldName === 'strings') {
return true;
}
return false;
};
const isBatchOutputField = (nodeType: string, fieldName: string) => {
if (nodeType === 'float_generator' && fieldName === 'floats') {
return true;
}
if (nodeType === 'integer_generator' && fieldName === 'integers') {
return true;
}
if (nodeType === 'string_generator' && fieldName === 'strings') {
return true;
}
return false;
};
const isNotInDenylist = (schema: InvocationSchemaObject) =>
!invocationDenylist.includes(schema.properties.type.default);
@@ -107,6 +133,7 @@ export const parseSchema = (
? {
name: property.ui_type,
cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE',
batch: false,
}
: null;
@@ -127,6 +154,10 @@ export const parseSchema = (
fieldType.originalType = deepClone(originalFieldType);
}
if (isBatchInputField(type, propertyName)) {
fieldType.batch = true;
}
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
inputsAccumulator[propertyName] = fieldInputTemplate;
@@ -172,6 +203,7 @@ export const parseSchema = (
? {
name: property.ui_type,
cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE',
batch: false,
}
: null;
@@ -187,6 +219,10 @@ export const parseSchema = (
fieldType.originalType = deepClone(originalFieldType);
}
if (isBatchOutputField(type, propertyName)) {
fieldType.batch = true;
}
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);
outputsAccumulator[propertyName] = fieldOutputTemplate;
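
So batch is parsed as false everywhere and then flipped by node type and field name for the handful of special fields. Illustrative results of the predicates above (note that image_batch's images input is not listed in isBatchInputField as shown here):

isBatchInputField('float_batch', 'floats'); // true -> input fieldType.batch = true
isBatchOutputField('string_generator', 'strings'); // true -> output fieldType.batch = true
isBatchInputField('string_generator', 'strings'); // false - generators flag outputs, not inputs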

View File

@@ -20,7 +20,7 @@ import { z } from 'zod';
* @param schema The zod schema to create a type guard from.
* @returns A type guard function for the schema.
*/
const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
export const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
};
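
Exporting buildTypeGuard lets other modules derive runtime guards directly from zod schemas. A hedged usage sketch with a hypothetical schema:

import { z } from 'zod';

const zPoint = z.object({ x: z.number(), y: z.number() }); // hypothetical schema
const isPoint = buildTypeGuard(zPoint);

isPoint({ x: 1, y: 2 }); // true, and narrows to { x: number; y: number }
isPoint('not a point'); // false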

View File

@@ -206,8 +206,16 @@ const QueueCountPredictionWorkflowsTab = memo(() => {
const iterationsCount = useAppSelector(selectIterations);
const text = useMemo(() => {
const generationCount = Math.min(batchSize * iterationsCount, 10000);
const iterations = t('queue.iterations', { count: iterationsCount });
if (batchSize === 'NO_BATCHES') {
const generationCount = Math.min(10000, iterationsCount);
const generations = t('queue.generations', { count: generationCount });
return `${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase();
}
if (batchSize === 'EMPTY_BATCHES') {
return t('parameters.invoke.invalidBatchConfigurationCannotCalculate');
}
const generationCount = Math.min(batchSize * iterationsCount, 10000);
const generations = t('queue.generations', { count: generationCount });
return `${batchSize} ${t('queue.batchSize')} \u00d7 ${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase();
}, [batchSize, iterationsCount, t]);

View File

@@ -1,4 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { AppConfig } from 'app/types/invokeai';
import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
@@ -18,14 +19,36 @@ import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { NodesState, Templates } from 'features/nodes/store/types';
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import {
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isFloatGeneratorFieldInputInstance,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerGeneratorFieldInputInstance,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
isStringGeneratorFieldInputInstance,
resolveFloatGeneratorField,
resolveIntegerGeneratorField,
resolveStringGeneratorField,
} from 'features/nodes/types/field';
import {
validateImageFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
} from 'features/nodes/types/fieldValidators';
import type { InvocationNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { forEach, groupBy, upperFirst } from 'lodash-es';
import { getConnectedEdges } from 'reactflow';
import { assert } from 'tsafe';
/**
* This file contains selectors and utilities for determining whether the app is ready to enqueue generations. The handling
@@ -47,6 +70,67 @@ export type Reason = { prefix?: string; content: string };
const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') });
export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNode[], edges: InvocationNodeEdge[]) => {
if (batchNode.data.type === 'image_batch') {
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
const ownValue = batchNode.data.inputs.images.value ?? [];
// no generators for images yet
return ownValue;
} else if (batchNode.data.type === 'string_batch') {
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
const ownValue = batchNode.data.inputs.strings.value;
const edgeToStrings = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'strings');
if (!edgeToStrings) {
return ownValue ?? [];
}
const generatorNode = nodes.find((node) => node.id === edgeToStrings.source);
assert(generatorNode, 'Missing edge from string generator to string batch');
const generatorField = generatorNode.data.inputs['generator'];
assert(isStringGeneratorFieldInputInstance(generatorField), 'Invalid string generator');
const generatorValue = resolveStringGeneratorField(generatorField);
return generatorValue;
} else if (batchNode.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
const ownValue = batchNode.data.inputs.floats.value;
const edgeToFloats = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'floats');
if (!edgeToFloats) {
return ownValue ?? [];
}
const generatorNode = nodes.find((node) => node.id === edgeToFloats.source);
assert(generatorNode, 'Missing edge from float generator to float batch');
const generatorField = generatorNode.data.inputs['generator'];
assert(isFloatGeneratorFieldInputInstance(generatorField), 'Invalid float generator');
const generatorValue = resolveFloatGeneratorField(generatorField);
return generatorValue;
} else if (batchNode.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
const ownValue = batchNode.data.inputs.integers.value;
const incomers = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'integers');
if (!incomers) {
return ownValue ?? [];
}
const generatorNode = nodes.find((node) => node.id === incomers.source);
assert(generatorNode, 'Missing edge from integer generator to integer batch');
const generatorField = generatorNode.data.inputs['generator'];
assert(isIntegerGeneratorFieldInputInstance(generatorField), 'Invalid integer generator field');
const generatorValue = resolveIntegerGeneratorField(generatorField);
return generatorValue;
}
assert(false, 'Invalid batch node type');
};
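
The resolution rule is the same for strings, floats, and integers: the batch node's own collection is used unless a generator node is wired into the matching input, in which case the resolved generator value wins (images have no generators yet). A self-contained restatement with simplified, assumed types:

type Generator = { resolve: () => string[] }; // stand-in for a *GeneratorFieldInputInstance
const resolveStringBatchSketch = (ownValue: string[] | undefined, generator?: Generator): string[] =>
generator ? generator.resolve() : (ownValue ?? []);

resolveStringBatchSketch(['a'], { resolve: () => ['x', 'y'] }); // ['x', 'y'] - wired generator wins
resolveStringBatchSketch(['a']); // ['a'] - falls back to the node's own value
resolveStringBatchSketch(undefined); // [] - nothing connected, nothing set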
const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
isConnected: boolean;
nodes: NodesState;
@@ -61,11 +145,54 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
}
if (workflowSettings.shouldValidateGraph) {
if (!nodes.nodes.length) {
const invocationNodes = nodes.nodes.filter(isInvocationNode);
const batchNodes = invocationNodes.filter(isBatchNode);
const executableNodes = invocationNodes.filter(isExecutableNode);
if (!executableNodes.length) {
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
}
nodes.nodes.forEach((node) => {
for (const node of batchNodes) {
if (nodes.edges.find((e) => e.source === node.id) === undefined) {
reasons.push({ content: i18n.t('parameters.invoke.batchNodeNotConnected', { label: node.data.label }) });
}
}
if (batchNodes.length > 1) {
const batchSizes: number[] = [];
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
// Grouped batch nodes must have the same collection size
const groupBatchSizes: number[] = [];
for (const node of batchNodes) {
const size = resolveBatchValue(node, invocationNodes, nodes.edges).length;
if (batchGroupId === 'None') {
// Ungrouped batch nodes may have differing collection sizes
batchSizes.push(size);
} else {
groupBatchSizes.push(size);
}
}
if (groupBatchSizes.some((count) => count !== groupBatchSizes[0])) {
reasons.push({
content: i18n.t('parameters.invoke.batchNodeCollectionSizeMismatch', { batchGroupId }),
});
}
if (groupBatchSizes[0] !== undefined) {
batchSizes.push(groupBatchSizes[0]);
}
}
if (batchSizes.some((size) => size === 0)) {
reasons.push({ content: i18n.t('parameters.invoke.batchNodeEmptyCollection') });
}
}
executableNodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
@@ -91,45 +218,38 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
return;
}
const baseTKeyOptions = {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
};
const prefix = `${node.data.label || nodeTemplate.title} -> ${field.label || fieldTemplate.title}`;
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
reasons.push({ content: i18n.t('parameters.invoke.missingInputForField', baseTKeyOptions) });
return;
reasons.push({ prefix, content: i18n.t('parameters.invoke.missingInputForField') });
} else if (
field.value &&
isImageFieldCollectionInputInstance(field) &&
isImageFieldCollectionInputTemplate(fieldTemplate)
) {
// Image collections may have min or max items to validate
// TODO(psyche): generalize this to other collection types
if (fieldTemplate.minItems !== undefined && fieldTemplate.minItems > 0 && field.value.length === 0) {
reasons.push({ content: i18n.t('parameters.invoke.collectionEmpty', baseTKeyOptions) });
return;
}
if (fieldTemplate.minItems !== undefined && field.value.length < fieldTemplate.minItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooFewItems', {
...baseTKeyOptions,
size: field.value.length,
minItems: fieldTemplate.minItems,
}),
});
return;
}
if (fieldTemplate.maxItems !== undefined && field.value.length > fieldTemplate.maxItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooManyItems', {
...baseTKeyOptions,
size: field.value.length,
maxItems: fieldTemplate.maxItems,
}),
});
return;
}
const errors = validateImageFieldCollectionValue(field.value, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
} else if (
field.value &&
isStringFieldCollectionInputInstance(field) &&
isStringFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateStringFieldCollectionValue(field.value, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
} else if (
field.value &&
isIntegerFieldCollectionInputInstance(field) &&
isIntegerFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
} else if (
field.value &&
isFloatFieldCollectionInputInstance(field) &&
isFloatFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
}
});
});
@@ -491,17 +611,80 @@ export const selectPromptsCount = createSelector(
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
);
export const selectWorkflowsBatchSize = createSelector(selectNodesSlice, ({ nodes }) =>
// The batch size is the product of all batch nodes' collection sizes
nodes.filter(isInvocationNode).reduce((batchSize, node) => {
if (!isImageFieldCollectionInputInstance(node.data.inputs.images)) {
return batchSize;
}
// If the batch size is not set, default to 1
batchSize = batchSize || 1;
// Multiply the batch size by the number of images in the batch
batchSize = batchSize * (node.data.inputs.images.value?.length ?? 0);
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
createMemoizedSelector(selectNodesSlice, ({ nodes, edges }) => {
const invocationNodes = nodes.filter(isInvocationNode);
return invocationNodes
.filter(isBatchNode)
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
.map((batchNodes) => resolveBatchValue(batchNodes, invocationNodes, edges).length);
});
return batchSize;
}, 0)
const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');
const selectGroup1BatchSizes = buildSelectGroupBatchSizes('Group 1');
const selectGroup2BatchSizes = buildSelectGroupBatchSizes('Group 2');
const selectGroup3BatchSizes = buildSelectGroupBatchSizes('Group 3');
const selectGroup4BatchSizes = buildSelectGroupBatchSizes('Group 4');
const selectGroup5BatchSizes = buildSelectGroupBatchSizes('Group 5');
export const selectWorkflowsBatchSize = createSelector(
selectUngroupedBatchSizes,
selectGroup1BatchSizes,
selectGroup2BatchSizes,
selectGroup3BatchSizes,
selectGroup4BatchSizes,
selectGroup5BatchSizes,
(
ungroupedBatchSizes,
group1BatchSizes,
group2BatchSizes,
group3BatchSizes,
group4BatchSizes,
group5BatchSizes
): number | 'EMPTY_BATCHES' | 'NO_BATCHES' => {
// Collect every batch node's collection size across all groups
const allBatchSizes = [
...ungroupedBatchSizes,
...group1BatchSizes,
...group2BatchSizes,
...group3BatchSizes,
...group4BatchSizes,
...group5BatchSizes,
];
// There are no batch nodes
if (allBatchSizes.length === 0) {
return 'NO_BATCHES';
}
// All batch nodes must have a populated collection
if (allBatchSizes.some((size) => size === 0)) {
return 'EMPTY_BATCHES';
}
for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) {
// Ignore groups with no batch nodes
if (group.length === 0) {
continue;
}
// Grouped batch nodes must have the same collection size
if (group.some((size) => size !== group[0])) {
return 'EMPTY_BATCHES';
}
}
// Total batch size = product of all ungrouped batches and each grouped batch
const totalBatchSize = [
...ungroupedBatchSizes,
// In case of no batch nodes in a group, fall back to 1 for the product calculation
group1BatchSizes[0] ?? 1,
group2BatchSizes[0] ?? 1,
group3BatchSizes[0] ?? 1,
group4BatchSizes[0] ?? 1,
group5BatchSizes[0] ?? 1,
].reduce((acc, size) => acc * size, 1);
return totalBatchSize;
}
);
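
The selector's contract, then: 'NO_BATCHES' when the workflow has no batch nodes, 'EMPTY_BATCHES' when any collection is empty or a group's sizes disagree, and otherwise the product of all ungrouped sizes with one representative size per group. A hedged worked example with assumed sizes:

const ungroupedSizes = [2, 3]; // ungrouped batch nodes may differ
const group1Sizes = [4, 4]; // grouped batch nodes must match
const total = [...ungroupedSizes, group1Sizes[0] ?? 1].reduce((acc, n) => acc * n, 1);
// total === 24 generations per iteration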

View File

@@ -9,6 +9,7 @@ const StylePresetImage = ({ presetImageUrl, imageWidth }: { presetImageUrl: stri
return (
<Tooltip
closeOnScroll
openDelay={0}
label={
presetImageUrl && (
<Image

View File

@@ -1,4 +1,4 @@
import type { components } from 'services/api/schema';
import type { paths } from 'services/api/schema';
import { api, buildV1Url } from '..';
@@ -13,8 +13,8 @@ const buildUtilitiesUrl = (path: string = '') => buildV1Url(`utilities/${path}`)
export const utilitiesApi = api.injectEndpoints({
endpoints: (build) => ({
dynamicPrompts: build.query<
components['schemas']['DynamicPromptsResponse'],
{ prompt: string; max_prompts: number }
paths['/api/v1/utilities/dynamicprompts']['post']['responses']['200']['content']['application/json'],
paths['/api/v1/utilities/dynamicprompts']['post']['requestBody']['content']['application/json']
>({
query: (arg) => ({
url: buildUtilitiesUrl('dynamicprompts'),
@@ -28,3 +28,5 @@ export const utilitiesApi = api.injectEndpoints({
}),
}),
});
export const { useDynamicPromptsQuery } = utilitiesApi;
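
Typing the endpoint off paths keeps the request and response types in lockstep with the generated OpenAPI schema. A hedged usage sketch of the exported hook - the argument shape matches the previous hand-written type, assuming the schema has not diverged:

const { data, isLoading } = useDynamicPromptsQuery({ prompt: 'a {red|blue} ball', max_prompts: 10 });
// data, when present, is the DynamicPromptsResponse payload defined by the schema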

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
__version__ = "5.6.0rc2"
__version__ = "5.6.0"

View File

@@ -3,7 +3,11 @@ import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
from tests.backend.model_manager.load.model_cache.cached_model.utils import (
DummyModule,
parameterize_keep_ram_copy,
parameterize_mps_and_cuda,
)
class NonTorchModel:
@@ -17,16 +21,22 @@ class NonTorchModel:
@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
@parameterize_keep_ram_copy
def test_cached_model_total_bytes(device: str, keep_ram_copy: bool):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert cached_model.total_bytes() == 100
@parameterize_mps_and_cuda
def test_cached_model_is_in_vram(device: str):
@parameterize_keep_ram_copy
def test_cached_model_is_in_vram(device: str, keep_ram_copy: bool):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert not cached_model.is_in_vram()
assert cached_model.cur_vram_bytes() == 0
@@ -40,9 +50,12 @@ def test_cached_model_is_in_vram(device: str):
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str):
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_unload(device: str, keep_ram_copy: bool):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert cached_model.full_load_to_vram() == 100
assert cached_model.is_in_vram()
assert all(p.device.type == device for p in cached_model.model.parameters())
@@ -55,7 +68,9 @@ def test_cached_model_full_load_and_unload(device: str):
@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=True
)
assert not cached_model.is_in_vram()
# The CPU state dict can be accessed and has the expected properties.
@@ -76,9 +91,12 @@ def test_cached_model_get_cpu_state_dict(device: str):
@parameterize_mps_and_cuda
def test_cached_model_full_load_and_inference(device: str):
@parameterize_keep_ram_copy
def test_cached_model_full_load_and_inference(device: str, keep_ram_copy: bool):
model = DummyModule()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert not cached_model.is_in_vram()
# Run inference on the CPU.
@@ -99,9 +117,12 @@ def test_cached_model_full_load_and_inference(device: str):
@parameterize_mps_and_cuda
def test_non_torch_model(device: str):
@parameterize_keep_ram_copy
def test_non_torch_model(device: str, keep_ram_copy: bool):
model = NonTorchModel()
cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
cached_model = CachedModelOnlyFullLoad(
model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
)
assert not cached_model.is_in_vram()
# The model does not have a CPU state dict.

View File

@@ -10,7 +10,11 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
apply_custom_layers_to_model,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
from tests.backend.model_manager.load.model_cache.cached_model.utils import (
DummyModule,
parameterize_keep_ram_copy,
parameterize_mps_and_cuda,
)
@pytest.fixture
@@ -21,8 +25,11 @@ def model():
@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str, model: DummyModule):
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
@parameterize_keep_ram_copy
def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
linear1_numel = 10 * 32 + 32
linear2_numel = 32 * 64 + 64
buffer1_numel = 64
@@ -31,9 +38,12 @@ def test_cached_model_total_bytes(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
assert cached_model.cur_vram_bytes() == 0
# Full load the model into VRAM.
@@ -45,9 +55,12 @@ def test_cached_model_cur_vram_bytes(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_load(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -71,9 +84,12 @@ def test_cached_model_partial_load(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_unload(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_copy: bool):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -99,9 +115,14 @@ def test_cached_model_partial_unload(device: str, model: DummyModule):
@parameterize_mps_and_cuda
def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule):
@parameterize_keep_ram_copy
def test_cached_model_partial_unload_keep_required_weights_in_vram(
device: str, model: DummyModule, keep_ram_copy: bool
):
# Model starts in CPU memory.
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
cached_model = CachedModelWithPartialLoad(
model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
)
model_total_bytes = cached_model.total_bytes()
assert cached_model.cur_vram_bytes() == 0
@@ -130,8 +151,11 @@ def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str,
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -162,8 +186,11 @@ def test_cached_model_full_load_and_unload(device: str, model: DummyModule):
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -190,8 +217,11 @@ def test_cached_model_full_load_from_partial(device: str, model: DummyModule):
 @parameterize_mps_and_cuda
-def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_unload_from_partial(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
@@ -219,7 +249,7 @@ def test_cached_model_full_unload_from_partial(device: str, model: DummyModule):
 @parameterize_mps_and_cuda
 def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device), keep_ram_copy=True)
     # Model starts in CPU memory.
     assert cached_model.cur_vram_bytes() == 0
@@ -242,8 +272,11 @@ def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule):
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     # Model starts in CPU memory.
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0
@@ -269,9 +302,12 @@ def test_cached_model_full_load_and_inference(device: str, model: DummyModule):
 @parameterize_mps_and_cuda
-def test_cached_model_partial_load_and_inference(device: str, model: DummyModule):
+@parameterize_keep_ram_copy
+def test_cached_model_partial_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool):
     # Model starts in CPU memory.
-    cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
+    cached_model = CachedModelWithPartialLoad(
+        model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy
+    )
     model_total_bytes = cached_model.total_bytes()
     assert cached_model.cur_vram_bytes() == 0


@@ -29,3 +29,5 @@ parameterize_mps_and_cuda = pytest.mark.parametrize(
         pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
     ],
 )
+
+parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
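Stacked pytest.mark.parametrize markers compose as a cross product, so applying this new marker together with parameterize_mps_and_cuda runs each test once per (device, keep_ram_copy) combination. A self-contained sketch of that behavior (the test name and the "cpu" stand-in device list are illustrative only):

    import pytest

    parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])

    @pytest.mark.parametrize("device", ["cpu"])  # stand-in for parameterize_mps_and_cuda
    @parameterize_keep_ram_copy
    def test_example(device: str, keep_ram_copy: bool):
        # Collected once per (device, keep_ram_copy) pair: ("cpu", True) and ("cpu", False).
        assert keep_ram_copy in (True, False)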


@@ -94,6 +94,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
     ram_cache = ModelCache(
         execution_device_working_mem_gb=mm2_app_config.device_working_mem_gb,
         enable_partial_loading=mm2_app_config.enable_partial_loading,
+        keep_ram_copy_of_weights=mm2_app_config.keep_ram_copy_of_weights,
         max_ram_cache_size_gb=mm2_app_config.max_cache_ram_gb,
         max_vram_cache_size_gb=mm2_app_config.max_cache_vram_gb,
         execution_device=TorchDevice.choose_torch_device(),
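This hunk threads the new keep_ram_copy_of_weights app-config field through to the ModelCache constructor. A hedged sketch of flipping the flag programmatically; the field name comes from the diff above, while constructing InvokeAIAppConfig directly like this is an assumption:

    from invokeai.app.services.config import InvokeAIAppConfig

    # Disabling the RAM copy presumably lowers steady-state RAM usage at the
    # cost of slower moves of weights back out of VRAM (an assumption about
    # the cache's behavior, not confirmed by this diff).
    config = InvokeAIAppConfig(keep_ram_copy_of_weights=False)
    assert config.keep_ram_copy_of_weights is False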


@@ -189,6 +189,26 @@ def test_cannot_create_bad_batch_items_type(batch_graph):
         )
 
 
+def test_number_type_interop(batch_graph):
+    # integers and floats can be mixed, should not throw an error
+    Batch(
+        graph=batch_graph,
+        data=[
+            [
+                BatchDatum(node_path="1", field_name="prompt", items=[1, 1.5]),
+            ]
+        ],
+    )
+    Batch(
+        graph=batch_graph,
+        data=[
+            [
+                BatchDatum(node_path="1", field_name="prompt", items=[1.5, 1]),
+            ]
+        ],
+    )
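For contrast with the int/float interop above, mixing numbers with a non-numeric type should still be rejected, as test_cannot_create_bad_batch_items_type in this file exercises. A hedged sketch of that failing case (the specific error message is not asserted because it is not shown in this diff):

    with pytest.raises(ValidationError):
        Batch(
            graph=batch_graph,
            data=[
                [
                    BatchDatum(node_path="1", field_name="prompt", items=["not a number", 1]),
                ]
            ],
        )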
 def test_cannot_create_bad_batch_unique_ids(batch_graph):
     with pytest.raises(ValidationError, match="Each batch data must have unique node_id and field_name"):
         Batch(