Compare commits

...

140 Commits

Author SHA1 Message Date
psychedelicious
ccc55069d1 chore: bump version to v6.3.0 2025-08-05 10:30:26 +10:00
psychedelicious
61ff9ee3a7 feat(ui): add button to ref image to recall size & optimize for model
This is useful for FLUX Kontext, where you typically want the generation
size to at least roughly match the first ref image size.
2025-08-05 10:28:44 +10:00
psychedelicious
111408c046 feat(mm): add flux krea to starter models 2025-08-05 10:25:14 +10:00
psychedelicious
d7619d465e feat(mm): change anime upscaling model to one that doesn't trigger picklescan 2025-08-05 10:25:14 +10:00
Kent Keirsey
8ad4f6e56d updates & fix 2025-08-05 10:10:52 +10:00
Cursor Agent
bf4899526f Add 'shift+s' hotkey for fitting bbox to canvas
Co-authored-by: kent <kent@invoke.ai>
2025-08-05 10:10:52 +10:00
psychedelicious
6435d265c6 fix(ui): overflow w/ long board names 2025-08-05 10:06:55 +10:00
Linos
3163ef454d translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (2065 of 2065 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-08-05 10:04:20 +10:00
Riccardo Giovanetti
7ea636df70 translationBot(ui): update translation (Italian)
Currently translated at 98.6% (2037 of 2065 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (2037 of 2065 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (2036 of 2065 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (2014 of 2042 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-08-05 10:04:20 +10:00
Hosted Weblate
1869824803 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-08-05 10:04:20 +10:00
psychedelicious
66fc8af8a6 fix(ui): reset session button actions
- Do not reset dimensions when resetting generation settings (they are
model-dependent, and we don't change model-dependent settings w/ that
butotn)
- Do not reset bbox when resetting canvas layers
- Show reset canvas layers button only on canvas tab
- Show reset generation settings button only on canvas or generate tab
2025-08-05 10:01:22 +10:00
psychedelicious
48cb6b12f0 fix(ui): add style ref launchpad using wrong dnd config
I don't think this actually caused problems bc the two DND targets were
very similar, but it was wrong.
2025-08-05 09:57:11 +10:00
psychedelicious
68e30a9864 feat(ui): prevent creating new canvases while staging
Disable these items while staging:
- New Canvas From Image context menu
- Edit image hook & launchpad button
- Generate from Text launchpad button (only while on canvas tab)
- Use a Layout Image launchpad button
2025-08-05 09:57:11 +10:00
psychedelicious
f65dc2c081 chore(ui): typegen 2025-08-05 09:54:00 +10:00
psychedelicious
0cd77443a7 feat(app): add setting to disable picklescan
When unsafe_disable_picklescan is enabled, instead of erroring on
detections or scan failures, a warning is logged.

A warning is also logged on app startup when this setting is enabled.

The setting is disabled by default and there is no change in behaviour
when disabled.
2025-08-05 09:54:00 +10:00
Mary Hipp
185ed86424 fix graph building 2025-08-04 12:32:27 -04:00
Mary Hipp
fed817ab83 add image concatenation to flux kontext graph if more than one refernece image 2025-08-04 11:27:02 -04:00
Mary Hipp
e0b45db69a remove check in readiness for multiple reg images 2025-08-04 11:27:02 -04:00
psychedelicious
2beac1fb04 chore: bump version to v6.3.0rc2 2025-08-04 23:55:04 +10:00
psychedelicious
e522de33f8 refactor(nodes): roll back latent-space resizing of kontext images 2025-08-04 23:03:12 +10:00
psychedelicious
d591b50c25 feat(ui): use image-space concatenation in FLUX graphs 2025-08-04 23:03:12 +10:00
psychedelicious
b365aad6d8 chore(ui): typegen 2025-08-04 23:03:12 +10:00
psychedelicious
65ad392361 feat(nodes): add node to prep images for FLUX Kontext 2025-08-04 23:03:12 +10:00
psychedelicious
56d75e1c77 feat(backend): use VAE mean encoding for Kontext reference images
Use distribution mean without sampling noise for more stable and
consistent reference image encoding, matching ComfyUI implementation
2025-08-04 23:03:12 +10:00
psychedelicious
df77a12efe refactor(backend): use torchvision transforms for Kontext image preprocessing
Replace numpy-based normalization with torchvision transforms for
consistency with other image processing in the codebase
2025-08-04 23:03:12 +10:00
psychedelicious
faf662d12e refactor(backend): use BICUBIC resampling for Kontext images
Switch from LANCZOS to BICUBIC for smoother image resizing to reduce
artifacts in reference image processing
2025-08-04 23:03:12 +10:00
psychedelicious
44a7dfd486 fix(backend): use consistent idx_offset=1 for all Kontext images
Changes from per-image index offsets to a consistent value of 1 for
all reference images, matching the ComfyUI implementation
2025-08-04 23:03:12 +10:00
psychedelicious
bb15e5cf06 feat(backend): add spatial tiling for multiple Kontext reference images
Implements intelligent spatial tiling that arranges multiple reference
images in a virtual canvas, choosing between horizontal and vertical
placement to maintain a square-like aspect ratio
2025-08-04 23:03:12 +10:00
psychedelicious
1a1c846be3 feat(backend): include reference images in negative CFG pass for Kontext
Maintains consistency between positive and negative passes to prevent
CFG artifacts when using Kontext reference images
2025-08-04 23:03:12 +10:00
psychedelicious
93c896a370 fix(backend): use img_cond_seq to check for Kontext slicing
Was incorrectly checking img_input_ids instead of img_cond_seq
2025-08-04 23:03:12 +10:00
psychedelicious
053d7c8c8e feat(ui): support disabling roarr output styling via localstorage 2025-07-31 23:02:45 +10:00
psychedelicious
5296263954 feat(ui): add missing translations 2025-07-31 22:51:33 +10:00
psychedelicious
a36b70c01c fix(ui): add image name data attr to gallery placeholder image elements
This fixes an issue where gallery's auto-scroll-into-view for selected
images didn't work, and users instead saw a "Unable to find image..."
debug log message in JS console.
2025-07-31 22:48:42 +10:00
psychedelicious
854a2a5a7a chore: bump version to v6.3.0rc1 2025-07-31 14:17:18 +10:00
psychedelicious
f9c64b0609 chore(ui): update whats new 2025-07-31 14:17:18 +10:00
psychedelicious
5889fa536a feat(ui): add migration path for client state from IndexedDB to server-backed storage 2025-07-31 14:09:45 +10:00
Linos
0e71ba892f translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (2044 of 2044 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-07-31 13:59:21 +10:00
Riccardo Giovanetti
d766a21223 translationBot(ui): update translation (Italian)
Currently translated at 98.6% (2016 of 2044 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-07-31 13:59:21 +10:00
psychedelicious
5c8c54eab8 chore: ruff 2025-07-31 06:38:48 +10:00
psychedelicious
f296f4525c tidy(ui): disable logging middleware 2025-07-31 06:38:48 +10:00
psychedelicious
7c9ba4cb52 refactor(ui): add persistence gate logic to prevent race conditions with slow rehydration 2025-07-31 06:38:48 +10:00
psychedelicious
6784fd5b43 refactor(ui): use new routes for _all_ client state persistence (no override/custom drivers) 2025-07-31 06:38:48 +10:00
psychedelicious
11d68cc646 chore(ui): typegen 2025-07-31 06:38:48 +10:00
psychedelicious
ea8c877025 refactor(app): move client state persistence to own route, add queue_id 2025-07-31 06:38:48 +10:00
psychedelicious
7a3c2332dd feat(ui): add visual indicator when input field is added to form 2025-07-31 06:33:22 +10:00
psychedelicious
3835fd2f72 feat(ui): zhoosh image comparison ui 2025-07-30 07:20:47 -04:00
psychedelicious
6f8746040c docs(ui): update comments in readiness re: flux kontext via bfl api 2025-07-30 12:26:48 +10:00
psychedelicious
35e3940a09 feat(ui): update warning when using multiple ref images on BFL API kontext
It only supports 1 image.
2025-07-30 12:26:48 +10:00
psychedelicious
415616d83f feat(ui): support multiple kontext ref images in studio 2025-07-30 12:26:48 +10:00
psychedelicious
afb67efef9 chore(ui): typegen 2025-07-30 12:26:48 +10:00
psychedelicious
1ed1fefa60 feat(nodes): support multiple kontext ref images
Images are concatenated in latent space.
2025-07-30 12:26:48 +10:00
Ar7ific1al
fa94a05c77 Update CanvasStateApiModule.ts
Add temporary grid snap with ctrl, optional small step with ctrl+shift, while grid snap is off
2025-07-30 12:16:42 +10:00
psychedelicious
7a23d8266f feat(ui): simpler storage driver impl 2025-07-30 05:53:20 +10:00
psychedelicious
a44de079dd perf(ui): instantiate logger for storage error handler once 2025-07-30 05:53:20 +10:00
psychedelicious
c3c1a3edd8 chore(ui): typegen 2025-07-30 05:53:20 +10:00
psychedelicious
ea26b5b147 feat(app): client state persistence endpoints accept stringified data 2025-07-30 05:53:20 +10:00
Eugene Brodsky
4226b741b1 fix(docker) rocm 6.3 based image (#8152)
1. Fix the run script to properly read the GPU_DRIVER
2. Cloned and adjusted the ROCM dockerbuild for docker
3. Adjust the docker-compose.yml to use the cloned dockerbuild
2025-07-29 10:16:42 -04:00
Eugene Brodsky
1424b7c254 Merge branch 'main' into bugfix/heathen711/rocm-docker 2025-07-29 10:12:13 -04:00
psychedelicious
933fb2294c fix(ui): zod rejects any board id besides "none"
Turns out the string autocomplete TS hack does not translate to zod.
Widen the zod schema to any string, but use the hack for the TS type.
2025-07-29 08:45:16 -04:00
psychedelicious
5a181ee0fd build(ui): export loading component 2025-07-29 08:43:03 -04:00
psychedelicious
3b0d59e459 tests(app): update mm tests to test updated behaviour 2025-07-29 16:08:15 +10:00
psychedelicious
fec296e41d fix(app): move (not copy) models from install tmpdir to destination
It's not clear why we were copying downloaded models to the destination
dir instead of moving them. I cannot find a reason for it, and I am able
to install single-file and diffusers models just fine with the change.

This fixes an issue where model installation requires 2x the model's
size (bc we were copying the model over).
2025-07-29 16:08:15 +10:00
Heathen711
ae4e38c6d0 Merge branch 'main' into bugfix/heathen711/rocm-docker 2025-07-28 21:24:34 -07:00
psychedelicious
a9f3f1a4b2 fix(app): handle model files with periods in their name
Previously, we used pathlib's `with_suffix()` method to change add a
suffix (e.g. ".safetensors") to a model when installing it.

The intention is to add a suffix to the model's name - but that method
actually replaces everything after the first period.

This can cause different models to be installed under the same name!

For example, the FLUX models all end up with the same name:
- "FLUX.1 schnell.safetensors" -> "FLUX.safetensors"
- "FLUX.1 dev.safetensors" -> "FLUX.safetensors"

The fix is easy - append the suffix using string formatting instead of
using pathlib.

This issue has existed for a long time, but was exacerbated in
075345bffd in which I updated the names of
our starter models, adding ".1" to the FLUX model names. Whoops!
2025-07-29 14:15:59 +10:00
psychedelicious
8a73df4fe1 fix(ui): progress image does not hide on viewer with autoswitch disabled 2025-07-29 12:53:45 +10:00
psychedelicious
ea2e1ea8f0 fix(ui): queue count badge renders when left panel collapsed 2025-07-29 12:51:23 +10:00
psychedelicious
e8aa91931d fix(ui): connect metadata to output node for ext api nodes 2025-07-29 06:46:17 +10:00
psychedelicious
8d22a314a6 docs(ui): add some comments for race condition handling 2025-07-29 06:34:08 +10:00
psychedelicious
57ce2b8aa7 chore(ui): lint 2025-07-29 06:34:08 +10:00
psychedelicious
6b810cb3fb fix(ui): race condition w/ queue counts 2025-07-29 06:34:08 +10:00
psychedelicious
4f3a5dcc43 tidy(ui): remove unused progress related logic and components 2025-07-29 06:34:08 +10:00
psychedelicious
c3ae14cf73 fix(ui): ignore events for already-completed queue items 2025-07-29 06:34:08 +10:00
psychedelicious
b9c44b92d5 fix(ui): clear progress images from viewer at the right time 2025-07-29 06:34:08 +10:00
psychedelicious
5a68b4ddbc build(ui): skip logging ctx plugin when running tests 2025-07-29 06:31:30 +10:00
psychedelicious
18a722839b chore(ui): update knip conifg 2025-07-29 06:31:30 +10:00
psychedelicious
7370cb9be6 build(ui): add vite plugin to add relative file path to logger context 2025-07-29 06:31:30 +10:00
Kent Keirsey
cc4df52f82 feat: server-side client state persistence (#8314)
## Summary

Move client state persistence from browser to server.

- Add new client state persistence service to handle reading and writing
client state to db & associated router. The API mirrors that of
LocalStorage/IndexedDB where the set/get methods both operate on _keys_.
For example, when we persist the canvas state, we send only the new
canvas state to the backend - not the whole app state.
- The data is very flexibly-typed as a pydantic `JsonValue`. The client
is expected to handle all data parsing/validation (it must do this
anyways, and does this today).
- Change persistence from debounced to throttled at 2 seconds. Maybe
less is OK? Trying to not hammer the server.
- Add new persistence storage driver in client and use it in
redux-remember. It does its best to avoid extraneous persist requests,
caching the last data it persisted and noop-ing if there are no changes.
- Storage driver tracks pending persist actions using ref counts (bc
each slice is persisted independently). If there user navigates away
from the page during a persist request, it will give them the "you may
lose something if you navigate away" alert.
- This "lose something" alert message is not customizable (browser
security reasons).
- The alert is triggered only when the user closes the tape while a
persist network request is mid-flight. It's possible that the user makes
a change and closes the page before we start persisting. In this case,
they will lose the last 2 seconds of data.
- I tried making triggering the alert when a persist was waiting to
start, and it felt off.
- Maybe the alert isn't even necessary. Again you'd lose 2s of data at
most, probably a non issue. IMO after trying it, a subtle indicator
somewhere on the page is probably less confusing/intrusive.
- Fix an issue where the `redux-remember` enhancer was added _last_ in
the enhancer chain, which prevented us detecting when a persist has
succeeded. This required a small change to the `unserialze` utility
(used during rehydration) to ensure slices enhanced with `redux-undo`
are set up correctly as they are rehydrated.
- Restructure the redux store code to avoid circular dependencies. I
couldn't figure out how to do this without just smooshing it all into
the main `store.ts` file. Oh well.

Implications:
- Because client state is now on the server, different browsers will
have the same studio state. For example, if I start working on something
in Firefox, if I switch to Chrome, I have the same client state.
- Incognito windows won't do anything bc client state is server-side.
- It takes a bit longer for persistence to happen thanks to the
debounce, but there's now an indicator that tells you your stuff isn't
saved yet.
- Resetting the browser won't fix an issue with your studio state. You
must use `Reset Web UI` to fix it (or otherwise hit the appropriate
endpoint). It may be possible to end up in a Catch-22 where you can't
click the button and get stuck w/ a borked studio - I think to think
through this a bit more, might not be an issue.
- It probably takes a bit longer to start up, since we need to retrieve
client state over network instead of directly with browser APIs.

Other notes:
- We could explore adding an "incognito" mode, enabled via
`invokeai.yaml` setting or maybe in the UI. This would temporarily
disable persistence. Actually, I don't think this really makes sense, bc
all the images would be saved to disk.
- The studio state is stored in a single row in the DB. Currently, a
static row ID is used to force the studio state to be a singleton. It is
_possible_ to support multiple saved states. Might be a solve for app
workspaces.

## Related Issues / Discussions

n/a

## QA Instructions

Try it out. It's pretty straightforward. Error states are the main
things to test - for example, network blips. The new server-side
persistence driver is the only real functional change - everything else
is just kinda shuffling things around to support it.

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-07-25 12:08:47 -04:00
Kent Keirsey
1cb4ef05a4 add newline 2025-07-25 11:08:54 -04:00
Kent Keirsey
7da141101c Merge branch 'main' into psyche/feat/app/client-state-persistence 2025-07-25 11:07:17 -04:00
psychedelicious
2571e199c5 tidy(ui): remove unused props 2025-07-25 11:06:18 -04:00
psychedelicious
79e93f905e fix(ui): add separate wrapper components for notes and current image nodes that do not need invocation node context 2025-07-25 11:06:18 -04:00
psychedelicious
f562e4f835 fix(ui): ensure all node context provider wraps all calls to useInvocationNodeContext 2025-07-25 11:06:18 -04:00
psychedelicious
47e220aaf3 perf(ui): imperatively get nodes and edges in autolayout hook 2025-07-25 11:06:18 -04:00
psychedelicious
9365154bfe chore: bump version to v6.2.0 2025-07-25 11:06:18 -04:00
psychedelicious
afc6911c96 chore: bump version to v6.3.0a1 2025-07-25 19:07:08 +10:00
psychedelicious
afa1ee7ffd tidy(ui): enable devmode redux checks 2025-07-25 19:04:21 +10:00
psychedelicious
5a102f6b53 chore(ui): lint 2025-07-25 19:04:21 +10:00
psychedelicious
af345a33f3 fix(ui): infinite loop when setting tile controlnet model 2025-07-25 19:04:21 +10:00
psychedelicious
038b110a82 fix(ui): do not store whole model configs in state 2025-07-25 19:04:21 +10:00
psychedelicious
f3cd49d46e refactor(ui): just manually validate async stuff 2025-07-25 19:04:21 +10:00
psychedelicious
ca7d7c9d93 refactor(ui): work around zod async validation issue 2025-07-25 19:04:21 +10:00
psychedelicious
1addeb4b59 fix(ui): check initial retrieval and set as last persisted 2025-07-25 19:04:21 +10:00
psychedelicious
6ea4884b0c chore(ui): bump zod to latest
Checking if it fixes an issue w/ async validators
2025-07-25 19:04:21 +10:00
psychedelicious
aed9b1013e refactor(ui): use zod for all redux state 2025-07-25 19:04:21 +10:00
psychedelicious
6962536b4a refactor(ui): use zod for all redux state (wip)
needed for confidence w/ state rehydration logic
2025-07-25 19:04:21 +10:00
psychedelicious
7e59d040aa feat(ui): iterate on storage api 2025-07-25 19:04:20 +10:00
psychedelicious
e7c67da2c2 refactor(ui): restructure persistence driver creation to support custom drivers 2025-07-25 19:04:20 +10:00
psychedelicious
c44571bc36 revert(ui): temp changes to main.tsx for testing 2025-07-25 19:04:20 +10:00
psychedelicious
ca257650d4 revert(ui): temp disable eslint rule 2025-07-25 19:04:20 +10:00
psychedelicious
6a9962d2bb git: update gitignore 2025-07-25 19:04:20 +10:00
psychedelicious
9492569a2c wip 2025-07-25 19:04:20 +10:00
psychedelicious
61e711620d chore: ruff 2025-07-25 19:04:20 +10:00
psychedelicious
3cf82505bb tests(app): service mocks 2025-07-25 19:04:20 +10:00
psychedelicious
53bcbc58f5 chore(ui): lint 2025-07-25 19:04:20 +10:00
psychedelicious
42f3990f7a refactor(ui): iterate on persistence 2025-07-25 19:04:20 +10:00
psychedelicious
456205da17 refactor(ui): iterate on persistence 2025-07-25 19:04:20 +10:00
psychedelicious
ca0684700e refactor(ui): alternate approach to slice configs 2025-07-25 19:04:19 +10:00
psychedelicious
6a702821ef chore(ui): typegen 2025-07-25 19:04:19 +10:00
psychedelicious
682d271f6f feat(api): make client state key query not body 2025-07-25 19:04:19 +10:00
psychedelicious
e872c253b1 refactor(ui): cleaner slice definitions 2025-07-25 19:04:19 +10:00
psychedelicious
28633c9983 feat: server-side client state persistence 2025-07-25 19:04:19 +10:00
psychedelicious
70ac58e64a tidy(ui): remove unused props 2025-07-25 18:51:21 +10:00
psychedelicious
e653837236 fix(ui): add separate wrapper components for notes and current image nodes that do not need invocation node context 2025-07-25 18:51:21 +10:00
psychedelicious
2bbfcc2f13 fix(ui): ensure all node context provider wraps all calls to useInvocationNodeContext 2025-07-25 18:51:21 +10:00
psychedelicious
d6e0e439c5 perf(ui): imperatively get nodes and edges in autolayout hook 2025-07-25 18:50:59 +10:00
psychedelicious
26aab60f81 chore: bump version to v6.2.0 2025-07-25 18:41:00 +10:00
Riccardo Giovanetti
7bea2fa11f translationBot(ui): update translation (Italian)
Currently translated at 98.6% (2016 of 2044 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (2015 of 2043 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-07-25 17:15:01 +10:00
Heathen711
1cdd4b5980 bugfix(docs) link syntax 2025-07-17 04:26:06 +00:00
Heathen711
89ceecc870 bugfix(docker) Ensure the correct extra install. 2025-07-17 04:19:22 +00:00
Heathen711
687cccdb99 cleanup(docker) 2025-07-17 04:00:42 +00:00
Heathen711
c84f8465b8 bugfix(pyproject) Convert from dependency groups to extras and update docks to use UV's built in torch support 2025-07-17 03:58:26 +00:00
Heathen711
4b5c481b7a Merge remote-tracking branch 'origin' into bugfix/heathen711/rocm-docker 2025-07-17 01:03:03 +00:00
Heathen711
2caa1b166d Merge remote-tracking branch 'origin' into bugfix/heathen711/rocm-docker 2025-07-13 00:55:39 +00:00
Heathen711
1b6ebede7b Revert "cleanup(github actions)"
This reverts commit 017d38eee2.
2025-07-10 21:10:56 +00:00
Heathen711
017d38eee2 cleanup(github actions) 2025-07-10 21:04:48 +00:00
Heathen711
78eb6b0338 cleanup(docker) 2025-07-10 21:03:57 +00:00
Heathen711
3e8e0f6ddf Merge remote-tracking branch 'origin' into bugfix/heathen711/rocm-docker 2025-07-10 20:14:27 +00:00
Heathen711
8213f62d3b bugfix(docker) render group controls the devices, but it needs to match the host's render group ID 2025-07-09 20:20:59 +00:00
Heathen711
233740a40e Merge remote-tracking branch 'origin' into bugfix/heathen711/rocm-docker 2025-07-09 03:27:42 +00:00
Heathen711
8c5fcfd0fd cleanup(docker) remove no cache argument 2025-07-05 15:25:26 +00:00
Heathen711
6d7b231196 Merge remote-tracking branch 'origin' into bugfix/heathen711/rocm-docker 2025-07-05 15:22:35 +00:00
Heathen711
31ca314b02 Missed files 2025-07-05 15:21:46 +00:00
Heathen711
0db304f1ee bugfix(uv) Lock torchvision and ensure the docker uses the same rocm version 2025-07-05 03:35:11 +00:00
Heathen711
a3cb3e03f4 bugfix(ci) Clean up more space for typegen check 2025-07-03 21:22:11 +00:00
Heathen711
641a6cfdb7 bugfix(docker) Remove the need for UV index as that is now baked into the uv.lock 2025-07-03 21:15:03 +00:00
Heathen711
f27471cea7 bugfix(docker): Use uv.lock for docker, and update to newer index urls. 2025-07-03 20:08:28 +00:00
Heathen711
47508b8d6c bugfix(docker) combined the dockerfiles and reduced image size 2025-07-03 06:01:51 +00:00
Heathen711
28e0242907 Fix tagging & remove force reinstall 2025-07-03 01:56:46 +00:00
Heathen711
96523ca01f fix(docker) Add cloned dockerbuild 2025-06-29 22:07:11 +00:00
Heathen711
c10a6fdab1 fix(docker) rocm 2.4.6 based image 2025-06-29 22:02:40 +00:00
186 changed files with 5111 additions and 2351 deletions

View File

@@ -39,6 +39,18 @@ jobs:
- name: checkout
uses: actions/checkout@v4
- name: Free up more disk space on the runner
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
run: |
echo "----- Free space before cleanup"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
echo "----- Free space after cleanup"
df -h
- name: check for changed files
if: ${{ inputs.always_run != true }}
id: changed-files

View File

@@ -22,6 +22,10 @@
## GPU_DRIVER can be set to either `cuda` or `rocm` to enable GPU support in the container accordingly.
# GPU_DRIVER=cuda #| rocm
## If you are using ROCM, you will need to ensure that the render group within the container and the host system use the same group ID.
## To obtain the group ID of the render group on the host system, run `getent group render` and grab the number.
# RENDER_GROUP_ID=
## CONTAINER_UID can be set to the UID of the user on the host system that should own the files in the container.
## It is usually not necessary to change this. Use `id -u` on the host system to find the UID.
# CONTAINER_UID=1000

View File

@@ -43,7 +43,6 @@ ENV \
UV_MANAGED_PYTHON=1 \
UV_LINK_MODE=copy \
UV_PROJECT_ENVIRONMENT=/opt/venv \
UV_INDEX="https://download.pytorch.org/whl/cu124" \
INVOKEAI_ROOT=/invokeai \
INVOKEAI_HOST=0.0.0.0 \
INVOKEAI_PORT=9090 \
@@ -74,19 +73,17 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
--mount=type=bind,source=invokeai/version,target=invokeai/version \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
fi && \
uv sync --frozen
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
ulimit -n 30000 && \
uv sync --extra $GPU_DRIVER --frozen
# Link amdgpu.ids for ROCm builds
# contributed by https://github.com/Rubonnek
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids" && groupadd render
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
@@ -105,8 +102,6 @@ COPY invokeai ${INVOKEAI_SRC}/invokeai
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
fi && \
uv pip install -e .
ulimit -n 30000 && \
uv pip install -e .[$GPU_DRIVER]

136
docker/Dockerfile-rocm-full Normal file
View File

@@ -0,0 +1,136 @@
# syntax=docker/dockerfile:1.4
#### Web UI ------------------------------------
FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
RUN corepack enable
WORKDIR /build
COPY invokeai/frontend/web/ ./
RUN --mount=type=cache,target=/pnpm/store \
pnpm install --frozen-lockfile
RUN npx vite build
## Backend ---------------------------------------
FROM library/ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
RUN --mount=type=cache,target=/var/cache/apt \
--mount=type=cache,target=/var/lib/apt \
apt update && apt install -y --no-install-recommends \
ca-certificates \
git \
gosu \
libglib2.0-0 \
libgl1 \
libglx-mesa0 \
build-essential \
libopencv-dev \
libstdc++-10-dev \
wget
ENV \
PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
VIRTUAL_ENV=/opt/venv \
INVOKEAI_SRC=/opt/invokeai \
PYTHON_VERSION=3.12 \
UV_PYTHON=3.12 \
UV_COMPILE_BYTECODE=1 \
UV_MANAGED_PYTHON=1 \
UV_LINK_MODE=copy \
UV_PROJECT_ENVIRONMENT=/opt/venv \
INVOKEAI_ROOT=/invokeai \
INVOKEAI_HOST=0.0.0.0 \
INVOKEAI_PORT=9090 \
PATH="/opt/venv/bin:$PATH" \
CONTAINER_UID=${CONTAINER_UID:-1000} \
CONTAINER_GID=${CONTAINER_GID:-1000}
ARG GPU_DRIVER=cuda
# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
RUN --mount=type=cache,target=/root/.cache/uv \
uv python install ${PYTHON_VERSION} && \
# chmod --recursive a+rX /root/.local/share/uv/python
chmod 711 /root
WORKDIR ${INVOKEAI_SRC}
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
--mount=type=bind,source=invokeai/version,target=invokeai/version \
ulimit -n 30000 && \
uv sync --extra $GPU_DRIVER --frozen
RUN --mount=type=cache,target=/var/cache/apt \
--mount=type=cache,target=/var/lib/apt \
if [ "$GPU_DRIVER" = "rocm" ]; then \
wget -O /tmp/amdgpu-install.deb \
https://repo.radeon.com/amdgpu-install/6.3.4/ubuntu/noble/amdgpu-install_6.3.60304-1_all.deb && \
apt install -y /tmp/amdgpu-install.deb && \
apt update && \
amdgpu-install --usecase=rocm -y && \
apt-get autoclean && \
apt clean && \
rm -rf /tmp/* /var/tmp/* && \
usermod -a -G render ubuntu && \
usermod -a -G video ubuntu && \
echo "\\n/opt/rocm/lib\\n/opt/rocm/lib64" >> /etc/ld.so.conf.d/rocm.conf && \
ldconfig && \
update-alternatives --auto rocm; \
fi
## Heathen711: Leaving this for review input, will remove before merge
# RUN --mount=type=cache,target=/var/cache/apt \
# --mount=type=cache,target=/var/lib/apt \
# if [ "$GPU_DRIVER" = "rocm" ]; then \
# groupadd render && \
# usermod -a -G render ubuntu && \
# usermod -a -G video ubuntu; \
# fi
## Link amdgpu.ids for ROCm builds
## contributed by https://github.com/Rubonnek
# RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
# ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
# build patchmatch
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python -c "from patchmatch import patch_match"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
CMD ["invokeai-web"]
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
# add sources last to minimize image changes on code changes
COPY invokeai ${INVOKEAI_SRC}/invokeai
# this should not increase image size because we've already installed dependencies
# in a previous layer
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=uv.lock,target=uv.lock \
ulimit -n 30000 && \
uv pip install -e .[$GPU_DRIVER]

View File

@@ -47,8 +47,9 @@ services:
invokeai-rocm:
<<: *invokeai
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
environment:
- AMD_VISIBLE_DEVICES=all
- RENDER_GROUP_ID=${RENDER_GROUP_ID}
runtime: amd
profiles:
- rocm

View File

@@ -21,6 +21,17 @@ _=$(id ${USER} 2>&1) || useradd -u ${USER_ID} ${USER}
# ensure the UID is correct
usermod -u ${USER_ID} ${USER} 1>/dev/null
## ROCM specific configuration
# render group within the container must match the host render group
# otherwise the container will not be able to access the host GPU.
if [[ -v "RENDER_GROUP_ID" ]] && [[ ! -z "${RENDER_GROUP_ID}" ]]; then
# ensure the render group exists
groupmod -g ${RENDER_GROUP_ID} render
usermod -a -G render ${USER}
usermod -a -G video ${USER}
fi
### Set the $PUBLIC_KEY env var to enable SSH access.
# We do not install openssh-server in the image by default to avoid bloat.
# but it is useful to have the full SSH server e.g. on Runpod.

View File

@@ -13,7 +13,7 @@ run() {
# parse .env file for build args
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
profile="$(awk -F '=' '/GPU_DRIVER=/ {print $2}' .env)"
# default to 'cuda' profile
[[ -z "$profile" ]] && profile="cuda"
@@ -30,7 +30,7 @@ run() {
printf "%s\n" "starting service $service_name"
docker compose --profile "$profile" up -d "$service_name"
docker compose logs -f
docker compose --profile "$profile" logs -f
}
run

View File

@@ -69,34 +69,34 @@ The following commands vary depending on the version of Invoke being installed a
- If you have an Nvidia 20xx series GPU or older, use `invokeai[xformers]`.
- If you have an Nvidia 30xx series GPU or newer, or do not have an Nvidia GPU, use `invokeai`.
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.
7. Determine the torch backend to use for installation, if any. This is necessary to get the right version of torch installed. This is acheived by using [UV's built in torch support.](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection)
=== "Invoke v5.12 and later"
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu128`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- **In all other cases, do not use an index.**
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu128`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.3`.
- **In all other cases, do not use a torch backend.**
=== "Invoke v5.10.0 to v5.11.0"
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu126`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- If you are on Windows or Linux with an Nvidia GPU, use `--torch-backend=cu126`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.2.4`.
- **In all other cases, do not use an index.**
=== "Invoke v5.0.0 to v5.9.1"
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.1`.
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm6.1`.
- **In all other cases, do not use an index.**
=== "Invoke v4"
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm5.2`.
- If you are on Windows with an Nvidia GPU, use `--torch-backend=cu124`.
- If you are on Linux with no GPU, use `--torch-backend=cpu`.
- If you are on Linux with an AMD GPU, use `--torch-backend=rocm5.2`.
- **In all other cases, do not use an index.**
8. Install the `invokeai` package. Substitute the package specifier and version.
@@ -105,10 +105,10 @@ The following commands vary depending on the version of Invoke being installed a
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
```
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
If you determined you needed to use a torch backend in the previous step, you'll need to set the backend like this:
```sh
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --torch-backend=<VERSION> --force-reinstall
```
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:

View File

@@ -10,6 +10,7 @@ from invokeai.app.services.board_images.board_images_default import BoardImagesS
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.boards.boards_default import BoardService
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
@@ -151,6 +152,7 @@ class ApiDependencies:
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
services = InvocationServices(
board_image_records=board_image_records,
@@ -181,6 +183,7 @@ class ApiDependencies:
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails,
client_state_persistence=client_state_persistence,
)
ApiDependencies.invoker = Invoker(services)

View File

@@ -0,0 +1,58 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.util.logging import logging
client_state_router = APIRouter(prefix="/v1/client_state", tags=["client_state"])
@client_state_router.get(
"/{queue_id}/get_by_key",
operation_id="get_client_state_by_key",
response_model=str | None,
)
async def get_client_state_by_key(
queue_id: str = Path(description="The queue id to perform this operation on"),
key: str = Query(..., description="Key to get"),
) -> str | None:
"""Gets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
except Exception as e:
logging.error(f"Error getting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@client_state_router.post(
"/{queue_id}/set_by_key",
operation_id="set_client_state",
response_model=str,
)
async def set_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
key: str = Query(..., description="Key to set"),
value: str = Body(..., description="Stringified value to set"),
) -> str:
"""Sets the client state"""
try:
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
except Exception as e:
logging.error(f"Error setting client state: {e}")
raise HTTPException(status_code=500, detail="Error setting client state")
@client_state_router.post(
"/{queue_id}/delete",
operation_id="delete_client_state",
responses={204: {"description": "Client state deleted"}},
)
async def delete_client_state(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> None:
"""Deletes the client state"""
try:
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
except Exception as e:
logging.error(f"Error deleting client state: {e}")
raise HTTPException(status_code=500, detail="Error deleting client state")

View File

@@ -19,6 +19,7 @@ from invokeai.app.api.routers import (
app_info,
board_images,
boards,
client_state,
download_queue,
images,
model_manager,
@@ -131,6 +132,7 @@ app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.include_router(client_state.client_state_router, prefix="/api")
app.openapi = get_openapi_func(app)
@@ -155,6 +157,12 @@ def overridden_redoc() -> HTMLResponse:
web_root_path = Path(list(web_dir.__path__)[0])
if app_config.unsafe_disable_picklescan:
logger.warning(
"The unsafe_disable_picklescan option is enabled. This disables malware scanning while installing and"
"loading models, which may allow malicious code to be executed. Use at your own risk."
)
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:

View File

@@ -63,7 +63,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="4.0.0",
version="4.1.0",
)
class FluxDenoiseInvocation(BaseInvocation):
"""Run denoising process with a FLUX transformer model."""
@@ -153,7 +153,7 @@ class FluxDenoiseInvocation(BaseInvocation):
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
)
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
default=None,
description="FLUX Kontext conditioning (reference image).",
input=Input.Connection,
@@ -386,13 +386,15 @@ class FluxDenoiseInvocation(BaseInvocation):
)
kontext_extension = None
if self.kontext_conditioning is not None:
if self.kontext_conditioning:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning,
kontext_conditioning=self.kontext_conditioning
if isinstance(self.kontext_conditioning, list)
else [self.kontext_conditioning],
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,

View File

@@ -1347,3 +1347,96 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar
image_dto = context.images.save(image=target_image)
return ImageOutput.build(image_dto)
@invocation(
"flux_kontext_image_prep",
title="FLUX Kontext Image Prep",
tags=["image", "concatenate", "flux", "kontext"],
category="image",
version="1.0.0",
)
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Prepares an image or images for use with FLUX Kontext. The first/single image is resized to the nearest
preferred Kontext resolution. All other images are concatenated horizontally, maintaining their aspect ratio."""
images: list[ImageField] = InputField(
description="The images to concatenate",
min_length=1,
max_length=10,
)
use_preferred_resolution: bool = InputField(
default=True, description="Use FLUX preferred resolutions for the first image"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
# Step 1: Load all images
pil_images = []
for image_field in self.images:
image = context.images.get_pil(image_field.image_name, mode="RGBA")
pil_images.append(image)
# Step 2: Determine target resolution for the first image
first_image = pil_images[0]
width, height = first_image.size
if self.use_preferred_resolution:
aspect_ratio = width / height
# Find the closest preferred resolution for the first image
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
scaled_height = 2 * int(target_height / 16)
final_height = 8 * scaled_height # This will be consistent for all images
scaled_width = 2 * int(target_width / 16)
first_width = 8 * scaled_width
else:
# Use original dimensions of first image, ensuring divisibility by 16
final_height = 16 * (height // 16)
first_width = 16 * (width // 16)
# Ensure minimum dimensions
if final_height < 16:
final_height = 16
if first_width < 16:
first_width = 16
# Step 3: Process and resize all images with consistent height
processed_images = []
total_width = 0
for i, image in enumerate(pil_images):
if i == 0:
# First image uses the calculated dimensions
final_width = first_width
else:
# Subsequent images maintain aspect ratio with the same height
img_aspect_ratio = image.width / image.height
# Calculate width that maintains aspect ratio at the target height
calculated_width = int(final_height * img_aspect_ratio)
# Ensure width is divisible by 16 for proper VAE encoding
final_width = 16 * (calculated_width // 16)
# Ensure minimum width
if final_width < 16:
final_width = 16
# Resize image to calculated dimensions
resized_image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
processed_images.append(resized_image)
total_width += final_width
# Step 4: Concatenate images horizontally
concatenated_image = Image.new("RGB", (total_width, final_height))
x_offset = 0
for img in processed_images:
concatenated_image.paste(img, (x_offset, 0))
x_offset += img.width
# Save the concatenated image
image_dto = context.images.save(image=concatenated_image)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
class ClientStatePersistenceABC(ABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
@abstractmethod
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
"""
Set a key-value pair for the client.
Args:
key (str): The key to set.
value (str): The value to set for the key.
Returns:
str: The value that was set.
"""
pass
@abstractmethod
def get_by_key(self, queue_id: str, key: str) -> str | None:
"""
Get the value for a specific key of the client.
Args:
key (str): The key to retrieve the value for.
Returns:
str | None: The value associated with the key, or None if the key does not exist.
"""
pass
@abstractmethod
def delete(self, queue_id: str) -> None:
"""
Delete all client state.
"""
pass

View File

@@ -0,0 +1,65 @@
import json
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class ClientStatePersistenceSqlite(ClientStatePersistenceABC):
"""
Base class for client persistence implementations.
This class defines the interface for persisting client data.
"""
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._default_row_id = 1
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def _get(self) -> dict[str, str] | None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
SELECT data FROM client_state
WHERE id = {self._default_row_id}
"""
)
row = cursor.fetchone()
if row is None:
return None
return json.loads(row[0])
def set_by_key(self, queue_id: str, key: str, value: str) -> str:
state = self._get() or {}
state.update({key: value})
with self._db.transaction() as cursor:
cursor.execute(
f"""
INSERT INTO client_state (id, data)
VALUES ({self._default_row_id}, ?)
ON CONFLICT(id) DO UPDATE
SET data = excluded.data;
""",
(json.dumps(state),),
)
return value
def get_by_key(self, queue_id: str, key: str) -> str | None:
state = self._get()
if state is None:
return None
return state.get(key, None)
def delete(self, queue_id: str) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""
DELETE FROM client_state
WHERE id = {self._default_row_id}
"""
)

View File

@@ -107,6 +107,7 @@ class InvokeAIAppConfig(BaseSettings):
hashing_algorithm: Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.<br>Valid values: `blake3_multi`, `blake3_single`, `random`, `md5`, `sha1`, `sha224`, `sha256`, `sha384`, `sha512`, `blake2b`, `blake2s`, `sha3_224`, `sha3_256`, `sha3_384`, `sha3_512`, `shake_128`, `shake_256`
remote_api_tokens: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
scan_models_on_startup: Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
"""
_root: Optional[Path] = PrivateAttr(default=None)
@@ -196,6 +197,7 @@ class InvokeAIAppConfig(BaseSettings):
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3_single", description="Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.")
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
scan_models_on_startup: bool = Field(default=False, description="Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.")
unsafe_disable_picklescan: bool = Field(default=False, description="UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.")
# fmt: on

View File

@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
@@ -73,6 +74,7 @@ class InvocationServices:
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
client_state_persistence: "ClientStatePersistenceABC",
):
self.board_images = board_images
self.board_image_records = board_image_records
@@ -102,3 +104,4 @@ class InvocationServices:
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails
self.client_state_persistence = client_state_persistence

View File

@@ -7,7 +7,7 @@ import threading
import time
from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from shutil import move, rmtree
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@@ -186,13 +186,14 @@ class ModelInstallService(ModelInstallServiceBase):
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
if preferred_name := config.name:
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
# Careful! Don't use pathlib.Path(...).with_suffix - it can will strip everything after the first dot.
preferred_name = f"{preferred_name}{model_path.suffix}"
dest_path = (
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
)
try:
new_path = self._copy_model(model_path, dest_path)
new_path = self._move_model(model_path, dest_path)
except FileExistsError as excp:
raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
@@ -617,16 +618,6 @@ class ModelInstallService(ModelInstallServiceBase):
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
if old_path.is_dir():
copytree(old_path, new_path)
else:
copyfile(old_path, new_path)
return new_path
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path

View File

@@ -87,9 +87,21 @@ class ModelLoadService(ModelLoadServiceBase):
def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
if self._app_config.unsafe_disable_picklescan:
self._logger.warning(
f"Model at {checkpoint} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
if scan_result.scan_err:
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
if self._app_config.unsafe_disable_picklescan:
self._logger.warning(
f"Error scanning model at {checkpoint} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
result = torch_load(checkpoint, map_location="cpu")
return result

View File

@@ -23,6 +23,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -63,6 +64,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_18())
migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20())
migrator.register_migration(build_migration_21())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,40 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration21Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""
CREATE TABLE client_state (
id INTEGER PRIMARY KEY CHECK(id = 1),
data TEXT NOT NULL, -- Frontend will handle the shape of this data
updated_at DATETIME NOT NULL DEFAULT (CURRENT_TIMESTAMP)
);
"""
)
cursor.execute(
"""
CREATE TRIGGER tg_client_state_updated_at
AFTER UPDATE ON client_state
FOR EACH ROW
BEGIN
UPDATE client_state
SET updated_at = CURRENT_TIMESTAMP
WHERE id = OLD.id;
END;
"""
)
def build_migration_21() -> Migration:
"""Builds the migration object for migrating from version 20 to version 21. This includes:
- Creating the `client_state` table.
- Adding a trigger to update the `updated_at` field on updates.
"""
return Migration(
from_version=20,
to_version=21,
callback=Migration21Callback(),
)

View File

@@ -112,7 +112,7 @@ def denoise(
)
# Slice prediction to only include the main image tokens
if img_input_ids is not None:
if img_cond_seq is not None:
pred = pred[:, :original_seq_len]
step_cfg_scale = cfg_scale[step_index]
@@ -125,9 +125,26 @@ def denoise(
if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
# For negative prediction with Kontext, we need to include the reference images
# to maintain consistency between positive and negative passes. Without this,
# CFG would create artifacts as the attention mechanism would see different
# spatial structures in each pass
neg_img_input = img
neg_img_input_ids = img_ids
# Add channel-wise conditioning for negative pass if present
if img_cond is not None:
neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (Kontext) for negative pass
# This ensures reference images are processed consistently
if img_cond_seq is not None:
neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
neg_pred = model(
img=img,
img_ids=img_ids,
img=neg_img_input,
img_ids=neg_img_input_ids,
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
@@ -140,6 +157,10 @@ def denoise(
ip_adapter_extensions=neg_ip_adapter_extensions,
regional_prompting_extension=neg_regional_prompting_extension,
)
# Slice negative prediction to match main image tokens
if img_cond_seq is not None:
neg_pred = neg_pred[:, :original_seq_len]
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
preview_img = img - t_curr * pred

View File

@@ -1,15 +1,14 @@
import einops
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import repeat
from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
from invokeai.backend.util.devices import TorchDevice
def generate_img_ids_with_offset(
@@ -19,8 +18,10 @@ def generate_img_ids_with_offset(
device: torch.device,
dtype: torch.dtype,
idx_offset: int = 0,
h_offset: int = 0,
w_offset: int = 0,
) -> torch.Tensor:
"""Generate tensor of image position ids with an optional offset.
"""Generate tensor of image position ids with optional index and spatial offsets.
Args:
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
@@ -28,7 +29,9 @@ def generate_img_ids_with_offset(
batch_size (int): Number of images in the batch.
device (torch.device): Device to create tensors on.
dtype (torch.dtype): Data type for the tensors.
idx_offset (int): Offset to add to the first dimension of the image ids.
idx_offset (int): Offset to add to the first dimension of the image ids (default: 0).
h_offset (int): Spatial offset for height/y-coordinates in latent space (default: 0).
w_offset (int): Spatial offset for width/x-coordinates in latent space (default: 0).
Returns:
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
@@ -42,6 +45,10 @@ def generate_img_ids_with_offset(
packed_height = latent_height // 2
packed_width = latent_width // 2
# Convert spatial offsets from latent space to packed space
packed_h_offset = h_offset // 2
packed_w_offset = w_offset // 2
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
# The 3 channels represent: [batch_offset, y_position, x_position]
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
@@ -49,13 +56,13 @@ def generate_img_ids_with_offset(
# Set the batch offset for all positions
img_ids[..., 0] = idx_offset
# Create y-coordinate indices (vertical positions)
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
# Create y-coordinate indices (vertical positions) with spatial offset
y_indices = torch.arange(packed_height, device=device, dtype=dtype) + packed_h_offset
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
img_ids[..., 1] = y_indices[:, None]
# Create x-coordinate indices (horizontal positions)
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
# Create x-coordinate indices (horizontal positions) with spatial offset
x_indices = torch.arange(packed_width, device=device, dtype=dtype) + packed_w_offset
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
img_ids[..., 2] = x_indices[None, :]
@@ -73,14 +80,14 @@ class KontextExtension:
def __init__(
self,
kontext_conditioning: FluxKontextConditioningField,
kontext_conditioning: list[FluxKontextConditioningField],
context: InvocationContext,
vae_field: VAEField,
device: torch.device,
dtype: torch.dtype,
):
"""
Initializes the KontextExtension, pre-processing the reference image
Initializes the KontextExtension, pre-processing the reference images
into latents and positional IDs.
"""
self._context = context
@@ -93,54 +100,101 @@ class KontextExtension:
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encodes the reference image and prepares its latents and IDs."""
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
"""Encodes the reference images and prepares their concatenated latents and IDs with spatial tiling."""
all_latents = []
all_ids = []
# Calculate aspect ratio of input image
width, height = image.size
aspect_ratio = width / height
# Track cumulative dimensions for spatial tiling
# These track the running extent of the virtual canvas in latent space
h = 0 # Running height extent
w = 0 # Running width extent
# Find the closest preferred resolution by aspect ratio
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
# This ensures compatibility with the model's training
scaled_width = 2 * int(target_width / 16)
scaled_height = 2 * int(target_height / 16)
# Resize to the exact resolution used during training
image = image.convert("RGB")
final_width = 8 * scaled_width
final_height = 8 * scaled_height
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
# Convert to tensor with same normalization as BFL
image_np = np.array(image)
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
vae_info = self._context.models.load(self._vae_field.vae)
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
for idx, kontext_field in enumerate(self.kontext_conditioning):
image = self._context.images.get_pil(kontext_field.image.image_name)
# Pack the latents and generate IDs
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1,
)
# Convert to RGB
image = image.convert("RGB")
return kontext_latents_packed, kontext_ids
# Convert to tensor using torchvision transforms for consistency
transformation = T.Compose(
[
T.ToTensor(), # Converts PIL image to tensor and scales to [0, 1]
]
)
image_tensor = transformation(image)
# Convert from [0, 1] to [-1, 1] range expected by VAE
image_tensor = image_tensor * 2.0 - 1.0
image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
# Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
# Use sample=False to get the distribution mean without noise
kontext_latents_unpacked = vae.encode(image_tensor, sample=False)
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pad latents to be compatible with patch_size=2
# This ensures dimensions are even for the pack() function
pad_h = (2 - latent_height % 2) % 2
pad_w = (2 - latent_width % 2) % 2
if pad_h > 0 or pad_w > 0:
kontext_latents_unpacked = F.pad(kontext_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
# Update dimensions after padding
_, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pack the latents
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
# Determine spatial offsets for this reference image
# - Compare the potential new canvas dimensions if we add the image vertically vs horizontally
# - Choose the placement that results in a more square-like canvas
h_offset = 0
w_offset = 0
if idx > 0: # First image starts at (0, 0)
# Check which placement would result in better canvas dimensions
# If adding to height would make the canvas taller than wide, tile horizontally
# Otherwise, tile vertically
if latent_height + h > latent_width + w:
# Tile horizontally (to the right of existing images)
w_offset = w
else:
# Tile vertically (below existing images)
h_offset = h
# Generate IDs with both index offset and spatial offsets
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1, # All reference images use index=1 (matching ComfyUI implementation)
h_offset=h_offset,
w_offset=w_offset,
)
# Update cumulative dimensions
# Track the maximum extent of the virtual canvas after placing this image
h = max(h, latent_height + h_offset)
w = max(w, latent_width + w_offset)
all_latents.append(kontext_latents_packed)
all_ids.append(kontext_ids)
# Concatenate all latents and IDs along the sequence dimension
concatenated_latents = torch.cat(all_latents, dim=1) # Concatenate along sequence dimension
concatenated_ids = torch.cat(all_ids, dim=1) # Concatenate along sequence dimension
return concatenated_latents, concatenated_ids
def ensure_batch_size(self, target_batch_size: int) -> None:
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""

View File

@@ -9,6 +9,7 @@ import spandrel
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.misc import uuid_string
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
@@ -493,9 +494,21 @@ class ModelProbe(object):
# scan model
scan_result = pscan.scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
if get_config().unsafe_disable_picklescan:
logger.warning(
f"The model {model_name} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"The model {model_name} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
raise Exception(f"Error scanning model {model_name} for malware. Aborting import.")
if get_config().unsafe_disable_picklescan:
logger.warning(
f"Error scanning the model at {model_name} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"Error scanning the model at {model_name} for malware. Aborting import.")
# Probing utilities

View File

@@ -6,13 +6,17 @@ import torch
from picklescan.scanner import scan_file_path
from safetensors import safe_open
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.silence_warnings import SilenceWarnings
StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?
logger = InvokeAILogger.get_logger()
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
@@ -79,8 +83,24 @@ class ModelOnDisk:
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
if scan_result.infected_files != 0:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"The model {path.stem} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(
f"The model {path.stem} is potentially infected by malware. Aborting import."
)
if scan_result.scan_err:
if get_config().unsafe_disable_picklescan:
logger.warning(
f"Error scanning the model at {path.stem} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"Error scanning the model at {path.stem} for malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):

View File

@@ -149,13 +149,29 @@ flux_kontext = StarterModel(
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext_quantized = StarterModel(
name="FLUX.1 Kontext dev (Quantized)",
name="FLUX.1 Kontext dev (quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_krea = StarterModel(
name="FLUX.1 Krea dev",
base=BaseModelType.Flux,
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev/resolve/main/flux1-krea-dev.safetensors",
description="FLUX.1 Krea dev. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_krea_quantized = StarterModel(
name="FLUX.1 Krea dev (quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev-GGUF/resolve/main/flux1-krea-dev-Q4_K_M.gguf",
description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
@@ -580,13 +596,14 @@ t2i_sketch_sdxl = StarterModel(
)
# endregion
# region SpandrelImageToImage
realesrgan_anime = StarterModel(
name="RealESRGAN_x4plus_anime_6B",
animesharp_v4_rcan = StarterModel(
name="2x-AnimeSharpV4_RCAN",
base=BaseModelType.Any,
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
description="A Real-ESRGAN 4x upscaling model (optimized for anime images).",
source="https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN.safetensors",
description="A 2x upscaling model (optimized for anime images).",
type=ModelType.SpandrelImageToImage,
)
realesrgan_x4 = StarterModel(
name="RealESRGAN_x4plus",
base=BaseModelType.Any,
@@ -732,7 +749,7 @@ STARTER_MODELS: list[StarterModel] = [
t2i_lineart_sdxl,
t2i_sketch_sdxl,
realesrgan_x4,
realesrgan_anime,
animesharp_v4_rcan,
realesrgan_x2,
swinir,
t5_base_encoder,
@@ -743,6 +760,8 @@ STARTER_MODELS: list[StarterModel] = [
llava_onevision,
flux_fill,
cogview4,
flux_krea,
flux_krea_quantized,
]
sd1_bundle: list[StarterModel] = [
@@ -794,6 +813,7 @@ flux_bundle: list[StarterModel] = [
flux_redux,
flux_fill,
flux_kontext_quantized,
flux_krea_quantized,
]
STARTER_BUNDLES: dict[str, StarterModelBundle] = {

View File

@@ -8,8 +8,12 @@ import picklescan.scanner as pscan
import safetensors
import torch
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.taxonomy import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
@@ -59,9 +63,21 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
if scan:
scan_result = pscan.scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
if get_config().unsafe_disable_picklescan:
logger.warning(
f"The model {path} is potentially infected by malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"The model {path} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
raise Exception(f"Error scanning model at {path} for malware. Aborting import.")
if get_config().unsafe_disable_picklescan:
logger.warning(
f"Error scanning the model at {path} for malware, but picklescan is disabled. "
"Proceeding with caution."
)
else:
raise RuntimeError(f"Error scanning the model at {path} for malware. Aborting import.")
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint

View File

@@ -44,4 +44,5 @@ yalc.lock
# vitest
tsconfig.vitest-temp.json
coverage/
coverage/
*.tgz

View File

@@ -26,7 +26,7 @@ i18n.use(initReactI18next).init({
returnNull: false,
});
const store = createStore(undefined, false);
const store = createStore();
$store.set(store);
$baseUrl.set('http://localhost:9090');

View File

@@ -197,6 +197,10 @@ export default [
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
{
name: 'zod/v3',
message: 'Import from zod instead.',
},
],
},
],

View File

@@ -17,6 +17,7 @@ const config: KnipConfig = {
'src/app/store/use-debounced-app-selector.ts',
],
ignoreBinaries: ['only-allow'],
ignoreDependencies: ['magic-string'],
paths: {
'public/*': ['public/*'],
},

View File

@@ -63,7 +63,7 @@
"framer-motion": "^11.10.0",
"i18next": "^25.3.2",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "6.2.2",
"idb-keyval": "6.2.1",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.22",
"linkify-react": "^4.3.1",
@@ -103,7 +103,7 @@
"use-debounce": "^10.0.5",
"use-device-pixel-ratio": "^1.1.2",
"uuid": "^11.1.0",
"zod": "^4.0.5",
"zod": "^4.0.10",
"zod-validation-error": "^3.5.2"
},
"peerDependencies": {
@@ -139,6 +139,7 @@
"eslint-plugin-unused-imports": "^4.1.4",
"globals": "^16.3.0",
"knip": "^5.61.3",
"magic-string": "^0.30.17",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",

View File

@@ -81,8 +81,8 @@ importers:
specifier: ^3.0.2
version: 3.0.2
idb-keyval:
specifier: 6.2.2
version: 6.2.2
specifier: 6.2.1
version: 6.2.1
jsondiffpatch:
specifier: ^0.7.3
version: 0.7.3
@@ -201,11 +201,11 @@ importers:
specifier: ^11.1.0
version: 11.1.0
zod:
specifier: ^4.0.5
version: 4.0.5
specifier: ^4.0.10
version: 4.0.10
zod-validation-error:
specifier: ^3.5.2
version: 3.5.3(zod@4.0.5)
version: 3.5.3(zod@4.0.10)
devDependencies:
'@eslint/js':
specifier: ^9.31.0
@@ -291,6 +291,9 @@ importers:
knip:
specifier: ^5.61.3
version: 5.61.3(@types/node@22.16.0)(typescript@5.8.3)
magic-string:
specifier: ^0.30.17
version: 0.30.17
openapi-types:
specifier: ^12.1.3
version: 12.1.3
@@ -411,6 +414,10 @@ packages:
resolution: {integrity: sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==}
engines: {node: '>=6.9.0'}
'@babel/runtime@7.28.2':
resolution: {integrity: sha512-KHp2IflsnGywDjBWDkR9iEqiWSpc8GIi0lgTT3mOElT0PP1tG26P4tmFI2YvAdzgq9RGyoHZQEIEdZy6Ec5xCA==}
engines: {node: '>=6.9.0'}
'@babel/template@7.27.2':
resolution: {integrity: sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==}
engines: {node: '>=6.9.0'}
@@ -2771,8 +2778,8 @@ packages:
typescript:
optional: true
idb-keyval@6.2.2:
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
idb-keyval@6.2.1:
resolution: {integrity: sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==}
ieee754@1.2.1:
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
@@ -4511,8 +4518,8 @@ packages:
zod@3.25.76:
resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==}
zod@4.0.5:
resolution: {integrity: sha512-/5UuuRPStvHXu7RS+gmvRf4NXrNxpSllGwDnCBcJZtQsKrviYXm54yDGV2KYNLT5kq0lHGcl7lqWJLgSaG+tgA==}
zod@4.0.10:
resolution: {integrity: sha512-3vB+UU3/VmLL2lvwcY/4RV2i9z/YU0DTV/tDuYjrwmx5WeJ7hwy+rGEEx8glHp6Yxw7ibRbKSaIFBgReRPe5KA==}
zustand@4.5.7:
resolution: {integrity: sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==}
@@ -4633,6 +4640,8 @@ snapshots:
'@babel/runtime@7.27.6': {}
'@babel/runtime@7.28.2': {}
'@babel/template@7.27.2':
dependencies:
'@babel/code-frame': 7.27.1
@@ -5736,7 +5745,7 @@ snapshots:
'@testing-library/dom@10.4.0':
dependencies:
'@babel/code-frame': 7.27.1
'@babel/runtime': 7.27.6
'@babel/runtime': 7.28.2
'@types/aria-query': 5.0.4
aria-query: 5.3.0
chalk: 4.1.2
@@ -7266,7 +7275,7 @@ snapshots:
optionalDependencies:
typescript: 5.8.3
idb-keyval@6.2.2: {}
idb-keyval@6.2.1: {}
ieee754@1.2.1: {}
@@ -9062,13 +9071,13 @@ snapshots:
dependencies:
zod: 3.25.76
zod-validation-error@3.5.3(zod@4.0.5):
zod-validation-error@3.5.3(zod@4.0.10):
dependencies:
zod: 4.0.5
zod: 4.0.10
zod@3.25.76: {}
zod@4.0.5: {}
zod@4.0.10: {}
zustand@4.5.7(@types/react@18.3.23)(immer@10.1.1)(react@18.3.1):
dependencies:

View File

@@ -1470,7 +1470,6 @@
"ui": {
"tabs": {
"queue": "Warteschlange",
"generation": "Erzeugung",
"gallery": "Galerie",
"models": "Modelle",
"upscaling": "Hochskalierung",

View File

@@ -610,6 +610,10 @@
"title": "Toggle Non-Raster Layers",
"desc": "Show or hide all non-raster layer categories (Control Layers, Inpaint Masks, Regional Guidance)."
},
"fitBboxToLayers": {
"title": "Fit Bbox To Layers",
"desc": "Automatically adjust the generation bounding box to fit visible layers"
},
"fitBboxToMasks": {
"title": "Fit Bbox To Masks",
"desc": "Automatically adjust the generation bounding box to fit visible inpaint masks"
@@ -1235,7 +1239,7 @@
"modelIncompatibleScaledBboxWidth": "Scaled bbox width is {{width}} but {{model}} requires multiple of {{multiple}}",
"modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}",
"fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time",
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with Flux Kontext",
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with FLUX Kontext via BFL API",
"canvasIsFiltering": "Canvas is busy (filtering)",
"canvasIsTransforming": "Canvas is busy (transforming)",
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
@@ -2066,6 +2070,8 @@
"asControlLayer": "As $t(controlLayers.controlLayer)",
"asControlLayerResize": "As $t(controlLayers.controlLayer) (Resize)",
"referenceImage": "Reference Image",
"maxRefImages": "Max Ref Images",
"useAsReferenceImage": "Use as Reference Image",
"regionalReferenceImage": "Regional Reference Image",
"globalReferenceImage": "Global Reference Image",
"sendingToCanvas": "Staging Generations on Canvas",
@@ -2533,7 +2539,7 @@
},
"ui": {
"tabs": {
"generation": "Generation",
"generate": "Generate",
"canvas": "Canvas",
"workflows": "Workflows",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
@@ -2544,6 +2550,12 @@
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
"gallery": "Gallery"
},
"panels": {
"launchpad": "Launchpad",
"workflowEditor": "Workflow Editor",
"imageViewer": "Image Viewer",
"canvas": "Canvas"
},
"launchpad": {
"workflowsTitle": "Go deep with Workflows.",
"upscalingTitle": "Upscale and add detail.",
@@ -2551,6 +2563,28 @@
"generateTitle": "Generate images from text prompts.",
"modelGuideText": "Want to learn what prompts work best for each model?",
"modelGuideLink": "Check out our Model Guide.",
"createNewWorkflowFromScratch": "Create a new Workflow from scratch",
"browseAndLoadWorkflows": "Browse and load existing workflows",
"addStyleRef": {
"title": "Add a Style Reference",
"description": "Add an image to transfer its look."
},
"editImage": {
"title": "Edit Image",
"description": "Add an image to refine."
},
"generateFromText": {
"title": "Generate from Text",
"description": "Enter a prompt and Invoke."
},
"useALayoutImage": {
"title": "Use a Layout Image",
"description": "Add an image to control composition."
},
"generate": {
"canvasCalloutTitle": "Looking to get more control, edit, and iterate on your images?",
"canvasCalloutLink": "Navigate to Canvas for more capabilities."
},
"workflows": {
"description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.",
"learnMoreLink": "Learn more about creating workflows",
@@ -2587,6 +2621,13 @@
"upscaleModel": "Upscale Model",
"model": "Model",
"scale": "Scale",
"creativityAndStructure": {
"title": "Creativity & Structure Defaults",
"conservative": "Conservative",
"balanced": "Balanced",
"creative": "Creative",
"artistic": "Artistic"
},
"helpText": {
"promptAdvice": "When upscaling, use a prompt that describes the medium and style. Avoid describing specific content details in the image.",
"styleAdvice": "Upscaling works best with the general style of your image."
@@ -2631,10 +2672,8 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"New setting to send all Canvas generations directly to the Gallery.",
"New Invert Mask (Shift+V) and Fit BBox to Mask (Shift+B) capabilities.",
"Expanded support for Model Thumbnails and configurations.",
"Various other quality of life updates and fixes"
"Studio state is saved to the server, allowing you to continue your work on any device.",
"Support for multiple reference images for FLUX Kontext (local model only)."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",

View File

@@ -399,7 +399,6 @@
"ui": {
"tabs": {
"canvas": "Lienzo",
"generation": "Generación",
"queue": "Cola",
"workflows": "Flujos de trabajo",
"models": "Modelos",

View File

@@ -1820,7 +1820,6 @@
"upscaling": "Agrandissement",
"gallery": "Galerie",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
"generation": "Génération",
"workflows": "Workflows",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
"models": "Modèles",

View File

@@ -254,8 +254,8 @@
"desc": "Attiva/disattiva il pannello destro."
},
"resetPanelLayout": {
"title": "Ripristina il layout del pannello",
"desc": "Ripristina le dimensioni e il layout predefiniti dei pannelli sinistro e destro."
"title": "Ripristina lo schema del pannello",
"desc": "Ripristina le dimensioni e lo schema predefiniti dei pannelli sinistro e destro."
},
"togglePanels": {
"title": "Attiva/disattiva i pannelli",
@@ -539,6 +539,10 @@
"galleryNavUpAlt": {
"desc": "Uguale a Naviga verso l'alto, ma seleziona l'immagine da confrontare, aprendo la modalità di confronto se non è già aperta.",
"title": "Naviga verso l'alto (Confronta immagine)"
},
"starImage": {
"desc": "Aggiungi/Rimuovi contrassegno all'immagine selezionata.",
"title": "Aggiungi / Rimuovi contrassegno immagine"
}
}
},
@@ -794,7 +798,7 @@
"modelIncompatibleScaledBboxWidth": "La larghezza scalata del riquadro è {{width}} ma {{model}} richiede multipli di {{multiple}}",
"modelIncompatibleScaledBboxHeight": "L'altezza scalata del riquadro è {{height}} ma {{model}} richiede multipli di {{multiple}}",
"modelDisabledForTrial": "La generazione con {{modelName}} non è disponibile per gli account di prova. Accedi alle impostazioni del tuo account per effettuare l'upgrade.",
"fluxKontextMultipleReferenceImages": "È possibile utilizzare solo 1 immagine di riferimento alla volta con Flux Kontext",
"fluxKontextMultipleReferenceImages": "È possibile utilizzare solo 1 immagine di riferimento alla volta con FLUX Kontext tramite BFL API",
"promptExpansionResultPending": "Accetta o ignora il risultato dell'espansione del prompt",
"promptExpansionPending": "Espansione del prompt in corso"
},
@@ -1162,7 +1166,19 @@
"unexpectedField_withName": "Campo \"{{name}}\" inaspettato",
"missingSourceOrTargetHandle": "Identificatore del nodo sorgente o di destinazione mancante",
"layout": {
"alignmentDR": "In basso a destra"
"alignmentDR": "In basso a destra",
"autoLayout": "Schema automatico",
"nodeSpacing": "Spaziatura nodi",
"layerSpacing": "Spaziatura livelli",
"layeringStrategy": "Strategia livelli",
"longestPath": "Percorso più lungo",
"layoutDirection": "Direzione schema",
"layoutDirectionRight": "A destra",
"layoutDirectionDown": "In basso",
"alignment": "Allineamento nodi",
"alignmentUL": "In alto a sinistra",
"alignmentDL": "In basso a sinistra",
"alignmentUR": "In alto a destra"
}
},
"boards": {
@@ -1240,7 +1256,7 @@
"batchQueuedDesc_other": "Aggiunte {{count}} sessioni a {{direction}} della coda",
"graphQueued": "Grafico in coda",
"batch": "Lotto",
"clearQueueAlertDialog": "Lo svuotamento della coda annulla immediatamente tutti gli elementi in elaborazione e cancella completamente la coda. I filtri in sospeso verranno annullati.",
"clearQueueAlertDialog": "La cancellazione della coda annulla immediatamente tutti gli elementi in elaborazione e cancella completamente la coda. I filtri in sospeso verranno annullati e l'area di lavoro della Tela verrà reimpostata.",
"pending": "In attesa",
"completedIn": "Completato in",
"resumeFailed": "Problema nel riavvio dell'elaborazione",
@@ -1296,7 +1312,8 @@
"retrySucceeded": "Elemento rieseguito",
"retryItem": "Riesegui elemento",
"retryFailed": "Problema riesecuzione elemento",
"credits": "Crediti"
"credits": "Crediti",
"cancelAllExceptCurrent": "Annulla tutto tranne quello corrente"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -1711,7 +1728,7 @@
"structure": {
"heading": "Struttura",
"paragraphs": [
"La struttura determina quanto l'immagine finale rispecchierà il layout dell'originale. Una struttura bassa permette cambiamenti significativi, mentre una struttura alta conserva la composizione e il layout originali."
"La struttura determina quanto l'immagine finale rispecchierà lo schema dell'originale. Un valore struttura basso permette cambiamenti significativi, mentre un valore struttura alto conserva la composizione e lo schema originali."
]
},
"fluxDevLicense": {
@@ -1877,7 +1894,7 @@
"opened": "Aperto",
"convertGraph": "Converti grafico",
"loadWorkflow": "$t(common.load) Flusso di lavoro",
"autoLayout": "Disposizione automatica",
"autoLayout": "Schema automatico",
"loadFromGraph": "Carica il flusso di lavoro dal grafico",
"userWorkflows": "Flussi di lavoro utente",
"projectWorkflows": "Flussi di lavoro del progetto",
@@ -2478,11 +2495,12 @@
"off": "Spento"
},
"invertMask": "Inverti maschera",
"fitBboxToMasks": "Adatta il riquadro di delimitazione alle maschere"
"fitBboxToMasks": "Adatta il riquadro di delimitazione alle maschere",
"maxRefImages": "Max Immagini di rif.to",
"useAsReferenceImage": "Usa come immagine di riferimento"
},
"ui": {
"tabs": {
"generation": "Generazione",
"canvas": "Tela",
"workflows": "Flussi di lavoro",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
@@ -2491,7 +2509,8 @@
"queue": "Coda",
"upscaling": "Amplia",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
"gallery": "Galleria"
"gallery": "Galleria",
"generate": "Genera"
},
"launchpad": {
"workflowsTitle": "Approfondisci i flussi di lavoro.",
@@ -2539,8 +2558,43 @@
"helpText": {
"promptAdvice": "Durante l'ampliamento, utilizza un prompt che descriva il mezzo e lo stile. Evita di descrivere dettagli specifici del contenuto dell'immagine.",
"styleAdvice": "L'ampliamento funziona meglio con lo stile generale dell'immagine."
},
"creativityAndStructure": {
"title": "Creatività e struttura predefinite",
"conservative": "Conservativo",
"balanced": "Bilanciato",
"creative": "Creativo",
"artistic": "Artistico"
}
},
"createNewWorkflowFromScratch": "Crea un nuovo flusso di lavoro da zero",
"browseAndLoadWorkflows": "Sfoglia e carica i flussi di lavoro esistenti",
"addStyleRef": {
"title": "Aggiungi un riferimento di stile",
"description": "Aggiungi un'immagine per trasferirne l'aspetto."
},
"editImage": {
"title": "Modifica immagine",
"description": "Aggiungi un'immagine da perfezionare."
},
"generateFromText": {
"title": "Genera da testo",
"description": "Inserisci un prompt e genera."
},
"useALayoutImage": {
"description": "Aggiungi un'immagine per controllare la composizione.",
"title": "Usa una immagine guida"
},
"generate": {
"canvasCalloutTitle": "Vuoi avere più controllo, modificare e affinare le tue immagini?",
"canvasCalloutLink": "Per ulteriori funzionalità, vai su Tela."
}
},
"panels": {
"launchpad": "Rampa di lancio",
"workflowEditor": "Editor del flusso di lavoro",
"imageViewer": "Visualizzatore immagini",
"canvas": "Tela"
}
},
"upscaling": {
@@ -2631,9 +2685,8 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Genera immagini più velocemente con le nuove Rampe di lancio e una scheda Genera semplificata.",
"Modifica con prompt utilizzando Flux Kontext Dev.",
"Esporta in PSD, nascondi sovrapposizioni in blocco, organizza modelli e immagini: il tutto in un'interfaccia riprogettata e pensata per il controllo."
"Lo stato dello studio viene salvato sul server, consentendoti di continuare a lavorare su qualsiasi dispositivo.",
"Supporto per più immagini di riferimento per FLUX Kontext (solo modello locale)."
]
},
"system": {

View File

@@ -1783,7 +1783,6 @@
"workflows": "ワークフロー",
"models": "モデル",
"gallery": "ギャラリー",
"generation": "生成",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"upscaling": "アップスケーリング",

View File

@@ -1931,7 +1931,6 @@
},
"ui": {
"tabs": {
"generation": "Генерация",
"canvas": "Холст",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
"models": "Модели",

View File

@@ -299,7 +299,7 @@
"pruneTooltip": "Cắt bớt {{item_count}} mục đã hoàn tất",
"pruneSucceeded": "Đã cắt bớt {{item_count}} mục đã hoàn tất khỏi hàng",
"clearTooltip": "Huỷ Và Dọn Dẹp Tất Cả Mục",
"clearQueueAlertDialog": "Dọn dẹp hàng đợi sẽ ngay lập tức huỷ tất cả mục đang xử lý và làm sạch hàng hoàn toàn. Bộ lọc đang chờ xử lý sẽ bị huỷ bỏ.",
"clearQueueAlertDialog": "Dọn dẹp hàng đợi sẽ ngay lập tức huỷ tất cả mục đang xử lý và làm sạch hàng hoàn toàn. Bộ lọc đang chờ xử lý sẽ bị huỷ bỏ và Vùng Dựng Canva sẽ được khởi động lại.",
"session": "Phiên",
"item": "Mục",
"resumeFailed": "Có Vấn Đề Khi Tiếp Tục Bộ Xử Lý",
@@ -343,13 +343,14 @@
"retrySucceeded": "Mục Đã Thử Lại",
"retryFailed": "Có Vấn Đề Khi Thử Lại Mục",
"retryItem": "Thử Lại Mục",
"credits": "Nguồn"
"credits": "Nguồn",
"cancelAllExceptCurrent": "Huỷ Bỏ Tất Cả Ngoại Trừ Mục Hiện Tại"
},
"hotkeys": {
"canvas": {
"fitLayersToCanvas": {
"title": "Xếp Vừa Layers Vào Canvas",
"desc": "Căn chỉnh để góc nhìn vừa vặn với tất cả layer."
"desc": "Căn chỉnh để góc nhìn vừa vặn với tất cả layer nhìn thấy dược."
},
"setZoomTo800Percent": {
"desc": "Phóng to canvas lên 800%.",
@@ -473,6 +474,24 @@
"toggleNonRasterLayers": {
"title": "Bật/Tắt Layer Không Thuộc Dạng Raster",
"desc": "Hiện hoặc ẩn tất cả layer không thuộc dạng raster (Layer Điều Khiển Được, Lớp Phủ Inpaint, Chỉ Dẫn Khu Vực)."
},
"invertMask": {
"title": "Đảo Ngược Lớp Phủ",
"desc": "Đảo ngược lớp phủ inpaint được chọn, tạo một lớp phủ mới với độ trong suốt đối nghịch."
},
"fitBboxToMasks": {
"title": "Xếp Vừa Hộp Giới Hạn Vào Lớp Phủ",
"desc": "Tự động điểu chỉnh hộp giới hạn tạo sinh vừa vặn vào lớp phủ inpaint nhìn thấy được"
},
"applySegmentAnything": {
"title": "Áp Dụng Segment Anything",
"desc": "Áp dụng lớp phủ Segment Anything hiện tại.",
"key": "enter"
},
"cancelSegmentAnything": {
"title": "Huỷ Segment Anything",
"desc": "Huỷ hoạt động Segment Anything hiện tại.",
"key": "esc"
}
},
"workflows": {
@@ -602,6 +621,10 @@
"clearSelection": {
"desc": "Xoá phần lựa chọn hiện tại nếu có.",
"title": "Xoá Phần Lựa Chọn"
},
"starImage": {
"title": "Dấu/Huỷ Sao Hình Ảnh",
"desc": "Đánh dấu sao hoặc huỷ đánh dấu sao ảnh được chọn."
}
},
"app": {
@@ -661,6 +684,11 @@
"selectModelsTab": {
"desc": "Chọn tab Model (Mô Hình).",
"title": "Chọn Tab Model"
},
"selectGenerateTab": {
"title": "Chọn Tab Tạo Sinh",
"desc": "Chọn tab Tạo Sinh.",
"key": "1"
}
},
"searchHotkeys": "Tìm Phím tắt",
@@ -1090,7 +1118,23 @@
"unknownField_withName": "Vùng Dữ Liệu Không Rõ \"{{name}}\"",
"unexpectedField_withName": "Sai Vùng Dữ Liệu \"{{name}}\"",
"unknownFieldEditWorkflowToFix_withName": "Workflow chứa vùng dữ liệu không rõ \"{{name}}\".\nHãy biên tập workflow để sửa lỗi.",
"missingField_withName": "Thiếu Vùng Dữ Liệu \"{{name}}\""
"missingField_withName": "Thiếu Vùng Dữ Liệu \"{{name}}\"",
"layout": {
"autoLayout": "Bố Cục Tự Động",
"layeringStrategy": "Chiến Lược Phân Layer",
"networkSimplex": "Network Simplex",
"longestPath": "Đường Đi Dài Nhất",
"nodeSpacing": "Khoảng Cách Node",
"layerSpacing": "Khoảng Cách Layer",
"layoutDirection": "Hướng Bố Cục",
"layoutDirectionRight": "Phải",
"layoutDirectionDown": "Xuống",
"alignment": "Căn Chỉnh Node",
"alignmentUL": "Trên Cùng Bên Trái",
"alignmentDL": "Dưới Cùng Bên Trái",
"alignmentUR": "Trên Cùng Bên Phải",
"alignmentDR": "Dưới Cùng Bên Phải"
}
},
"popovers": {
"paramCFGRescaleMultiplier": {
@@ -1597,7 +1641,7 @@
"modelIncompatibleScaledBboxHeight": "Chiều dài hộp giới hạn theo tỉ lệ là {{height}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
"modelIncompatibleScaledBboxWidth": "Chiều rộng hộp giới hạn theo tỉ lệ là {{width}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
"modelDisabledForTrial": "Tạo sinh với {{modelName}} là không thể với tài khoản trial. Vào phần thiết lập tài khoản để nâng cấp.",
"fluxKontextMultipleReferenceImages": "Chỉ có thể dùng 1 Ảnh Mẫu cùng lúc với Flux Kontext",
"fluxKontextMultipleReferenceImages": "Chỉ có thể dùng 1 Ảnh Mẫu cùng lúc với LUX Kontext thông qua BFL API",
"promptExpansionPending": "Trong quá trình mở rộng lệnh",
"promptExpansionResultPending": "Hãy chấp thuận hoặc huỷ bỏ kết quả mở rộng lệnh của bạn"
},
@@ -2192,7 +2236,11 @@
"off": "Tắt",
"switchOnStart": "Khi Bắt Đầu",
"switchOnFinish": "Khi Kết Thúc"
}
},
"fitBboxToMasks": "Xếp Vừa Hộp Giới Hạn Vào Lớp Phủ",
"invertMask": "Đảo Ngược Lớp Phủ",
"maxRefImages": "Ảnh Mẫu Tối Đa",
"useAsReferenceImage": "Dùng Làm Ảnh Mẫu"
},
"stylePresets": {
"negativePrompt": "Lệnh Tiêu Cực",
@@ -2354,20 +2402,28 @@
"noValidLayerAdapters": "Không có Layer Adaper Phù Hợp",
"promptGenerationStarted": "Trình tạo sinh lệnh khởi động",
"uploadAndPromptGenerationFailed": "Thất bại khi tải lên ảnh để tạo sinh lệnh",
"promptExpansionFailed": "Có vấn đề xảy ra. Hãy thử mở rộng lệnh lại."
"promptExpansionFailed": "Có vấn đề xảy ra. Hãy thử mở rộng lệnh lại.",
"maskInverted": "Đã Đảo Ngược Lớp Phủ",
"maskInvertFailed": "Thất Bại Khi Đảo Ngược Lớp Phủ",
"noVisibleMasks": "Không Có Lớp Phủ Đang Hiển Thị",
"noVisibleMasksDesc": "Tạo hoặc bật ít nhất một lớp phủ inpaint để đảo ngược",
"noInpaintMaskSelected": "Không Có Lớp Phủ Inpant Được Chọn",
"noInpaintMaskSelectedDesc": "Chọn một lớp phủ inpaint để đảo ngược",
"invalidBbox": "Hộp Giới Hạn Không Hợp Lệ",
"invalidBboxDesc": "Hợp giới hạn có kích thước không hợp lệ"
},
"ui": {
"tabs": {
"gallery": "Thư Viện Ảnh",
"models": "Models",
"generation": "Generation (Máy Tạo Sinh)",
"upscaling": "Upscale (Nâng Cấp Chất Lượng Hình Ảnh)",
"canvas": "Canvas (Vùng Ảnh)",
"upscalingTab": "$t(common.tab) $t(ui.tabs.upscaling)",
"modelsTab": "$t(common.tab) $t(ui.tabs.models)",
"queue": "Queue (Hàng Đợi)",
"workflows": "Workflow (Luồng Làm Việc)",
"workflowsTab": "$t(common.tab) $t(ui.tabs.workflows)"
"workflowsTab": "$t(common.tab) $t(ui.tabs.workflows)",
"generate": "Tạo Sinh"
},
"launchpad": {
"workflowsTitle": "Đi sâu hơn với Workflow.",
@@ -2415,8 +2471,43 @@
"promptAdvice": "Khi upscale, dùng lệnh để mô tả phương thức và phong cách. Tránh mô tả các chi tiết cụ thể trong ảnh.",
"styleAdvice": "Upscale thích hợp nhất cho phong cách chung của ảnh."
},
"scale": "Kích Thước"
"scale": "Kích Thước",
"creativityAndStructure": {
"title": "Độ Sáng Tạo & Cấu Trúc Mặc Định",
"conservative": "Bảo toàn",
"balanced": "Cân bằng",
"creative": "Sáng tạo",
"artistic": "Thẩm mỹ"
}
},
"createNewWorkflowFromScratch": "Tạo workflow mới từ đầu",
"browseAndLoadWorkflows": "Duyệt và tải workflow có sẵn",
"addStyleRef": {
"title": "Thêm Phong Cách Mẫu",
"description": "Thêm ảnh để chuyển đổi diện mạo của nó."
},
"editImage": {
"title": "Biên Tập Ảnh",
"description": "Thêm ảnh để chỉnh sửa."
},
"generateFromText": {
"title": "Tạo Sinh Từ Chữ",
"description": "Nhập lệnh vào và Kích Hoạt."
},
"useALayoutImage": {
"title": "Dùng Bố Cục Ảnh",
"description": "Thêm ảnh để điều khiển bố cục."
},
"generate": {
"canvasCalloutTitle": "Đang tìm cách để điều khiển, chỉnh sửa, và làm lại ảnh?",
"canvasCalloutLink": "Vào Canvas cho nhiều tính năng hơn."
}
},
"panels": {
"launchpad": "Launchpad",
"workflowEditor": "Trình Biên Tập Workflow",
"imageViewer": "Trình Xem Ảnh",
"canvas": "Canvas"
}
},
"workflows": {
@@ -2588,9 +2679,8 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"Tạo sinh ảnh nhanh hơn với Launchpad và thẻ Tạo Sinh đã cơ bản hoá.",
"Biên tập với lệnh bằng Flux Kontext Dev.",
"Xuất ra file PSD, ẩn số lượng lớn lớp phủ, sắp xếp model & ảnh — tất cả cho một giao diện đã thiết kế lại để chuyên điều khiển."
"Trạng thái Studio được lưu vào server, giúp bạn tiếp tục công việc ở mọi thiết bị.",
"Hỗ trợ nhiều ảnh mẫu cho FLUX KONTEXT (chỉ cho model trên máy)."
]
},
"upsell": {

View File

@@ -1772,7 +1772,6 @@
},
"ui": {
"tabs": {
"generation": "生成",
"queue": "队列",
"canvas": "画布",
"upscaling": "放大中",

View File

@@ -3,9 +3,9 @@ import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import { clearStorage } from 'app/store/enhancers/reduxRemember/driver';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { AppContent } from 'features/ui/components/AppContent';
import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
@@ -21,13 +21,12 @@ interface Props {
const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
const didStudioInit = useStore($didStudioInit);
const clearStorage = useClearStorage();
const handleReset = useCallback(() => {
clearStorage();
location.reload();
return false;
}, [clearStorage]);
}, []);
return (
<ThemeLocaleProvider>

View File

@@ -5,6 +5,7 @@ import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { LoggingOverrides } from 'app/logging/logger';
import { $loggingOverrides, configureLogging } from 'app/logging/logger';
import { addStorageListeners } from 'app/store/enhancers/reduxRemember/driver';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -35,7 +36,7 @@ import {
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { ToastConfig } from 'features/toast/toast';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useMemo } from 'react';
import React, { lazy, memo, useEffect, useLayoutEffect, useState } from 'react';
import { Provider } from 'react-redux';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { $socketOptions } from 'services/events/stores';
@@ -70,6 +71,7 @@ interface Props extends PropsWithChildren {
* If provided, overrides in-app navigation to the model manager
*/
onClickGoToModelManager?: () => void;
storagePersistThrottle?: number;
}
const InvokeAIUI = ({
@@ -96,7 +98,11 @@ const InvokeAIUI = ({
loggingOverrides,
onClickGoToModelManager,
whatsNew,
storagePersistThrottle = 2000,
}: Props) => {
const [store, setStore] = useState<ReturnType<typeof createStore> | undefined>(undefined);
const [didRehydrate, setDidRehydrate] = useState(false);
useLayoutEffect(() => {
/*
* We need to configure logging before anything else happens - useLayoutEffect ensures we set this at the first
@@ -308,22 +314,30 @@ const InvokeAIUI = ({
};
}, [isDebugging]);
const store = useMemo(() => {
return createStore(projectId);
}, [projectId]);
useEffect(() => {
const onRehydrated = () => {
setDidRehydrate(true);
};
const store = createStore({ persist: true, persistThrottle: storagePersistThrottle, onRehydrated });
setStore(store);
$store.set(store);
if (import.meta.env.MODE === 'development') {
window.$store = $store;
}
const removeStorageListeners = addStorageListeners();
return () => {
removeStorageListeners();
setStore(undefined);
$store.set(undefined);
if (import.meta.env.MODE === 'development') {
window.$store = undefined;
}
};
}, [store]);
}, [storagePersistThrottle]);
if (!store || !didRehydrate) {
return <Loading />;
}
return (
<React.StrictMode>

View File

@@ -93,5 +93,7 @@ export const configureLogging = (
localStorage.setItem('ROARR_FILTER', filter);
}
ROARR.write = createLogWriter();
const styleOutput = localStorage.getItem('ROARR_STYLE_OUTPUT') === 'false' ? false : true;
ROARR.write = createLogWriter({ styleOutput });
};

View File

@@ -1,3 +1,2 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
export const EMPTY_OBJECT = {};

View File

@@ -1,40 +1,209 @@
import { logger } from 'app/logging/logger';
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import { $authToken } from 'app/store/nanostores/authToken';
import { $projectId } from 'app/store/nanostores/projectId';
import { $queueId } from 'app/store/nanostores/queueId';
import type { UseStore } from 'idb-keyval';
import { clear, createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
import { atom } from 'nanostores';
import { createStore as idbCreateStore, del as idbDel, get as idbGet } from 'idb-keyval';
import type { Driver } from 'redux-remember';
import { serializeError } from 'serialize-error';
import { buildV1Url, getBaseUrl } from 'services/api';
import type { JsonObject } from 'type-fest';
// Create a custom idb-keyval store (just needed to customize the name)
const $idbKeyValStore = atom<UseStore>(createIDBKeyValStore('invoke', 'invoke-store'));
const log = logger('system');
export const clearIdbKeyValStore = () => {
clear($idbKeyValStore.get());
const getUrl = (endpoint: 'get_by_key' | 'set_by_key' | 'delete', key?: string) => {
const baseUrl = getBaseUrl();
const query: Record<string, string> = {};
if (key) {
query['key'] = key;
}
const path = buildV1Url(`client_state/${$queueId.get()}/${endpoint}`, query);
const url = `${baseUrl}/${path}`;
return url;
};
// Create redux-remember driver, wrapping idb-keyval
export const idbKeyValDriver: Driver = {
getItem: (key) => {
try {
return get(key, $idbKeyValStore.get());
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
},
setItem: (key, value) => {
try {
return set(key, value, $idbKeyValStore.get());
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
}
},
const getHeaders = () => {
const headers = new Headers();
const authToken = $authToken.get();
const projectId = $projectId.get();
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
}
if (projectId) {
headers.set('project-id', projectId);
}
return headers;
};
// Persistence happens per slice. To track when persistence is in progress, maintain a ref count, incrementing
// it when a slice is being persisted and decrementing it when the persistence is done.
let persistRefCount = 0;
// Keep track of the last persisted state for each key to avoid unnecessary network requests.
//
// `redux-remember` persists individual slices of state, so we can implicity denylist a slice by not giving it a
// persist config.
//
// However, we may need to avoid persisting individual _fields_ of a slice. `redux-remember` does not provide a
// way to do this directly.
//
// To accomplish this, we add a layer of logic on top of the `redux-remember`. In the state serializer function
// provided to `redux-remember`, we can omit certain fields from the state that we do not want to persist. See
// the implementation in `store.ts` for this logic.
//
// This logic is unknown to `redux-remember`. When an omitted field changes, it will still attempt to persist the
// whole slice, even if the final, _serialized_ slice value is unchanged.
//
// To avoid unnecessary network requests, we keep track of the last persisted state for each key in this map.
// If the value to be persisted is the same as the last persisted value, we will skip the network request.
const lastPersistedState = new Map<string, string | undefined>();
// As of v6.3.0, we use server-backed storage for client state. This replaces the previous IndexedDB-based storage,
// which was implemented using `idb-keyval`.
//
// To facilitate a smooth transition, we implement a migration strategy that attempts to retrieve values from IndexedDB
// and persist them to the new server-backed storage. This is done on a best-effort basis.
// These constants were used in the previous IndexedDB-based storage implementation.
const IDB_DB_NAME = 'invoke';
const IDB_STORE_NAME = 'invoke-store';
const IDB_STORAGE_PREFIX = '@@invokeai-';
// Lazy store creation
let _idbKeyValStore: UseStore | null = null;
const getIdbKeyValStore = () => {
if (_idbKeyValStore === null) {
_idbKeyValStore = idbCreateStore(IDB_DB_NAME, IDB_STORE_NAME);
}
return _idbKeyValStore;
};
const getIdbKey = (key: string) => {
return `${IDB_STORAGE_PREFIX}${key}`;
};
const getItem = async (key: string) => {
try {
const url = getUrl('get_by_key', key);
const headers = getHeaders();
const res = await fetch(url, { method: 'GET', headers });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
const value = await res.json();
// Best-effort migration from IndexedDB to the new storage system
log.trace({ key, value }, 'Server-backed storage value retrieved');
if (!value) {
const idbKey = getIdbKey(key);
try {
// It's a bit tricky to query IndexedDB directly to check if value exists, so we use `idb-keyval` to do it.
// Thing is, `idb-keyval` requires you to create a store to query it. End result - we are creating a store
// even if we don't use it for anything besides checking if the key is present.
const idbKeyValStore = getIdbKeyValStore();
const idbValue = await idbGet(idbKey, idbKeyValStore);
if (idbValue) {
log.debug(
{ key, idbKey, idbValue },
'No value in server-backed storage, but found value in IndexedDB - attempting migration'
);
await idbDel(idbKey, idbKeyValStore);
await setItem(key, idbValue);
log.debug({ key, idbKey, idbValue }, 'Migration successful');
return idbValue;
}
} catch (error) {
// Just log if IndexedDB retrieval fails - this is a best-effort migration.
log.debug(
{ key, idbKey, error: serializeError(error) } as JsonObject,
'Error checking for or migrating from IndexedDB'
);
}
}
lastPersistedState.set(key, value);
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Getting state for ${key}`);
return value;
} catch (originalError) {
throw new StorageError({
key,
projectId: $projectId.get(),
originalError,
});
}
};
const setItem = async (key: string, value: string) => {
try {
persistRefCount++;
if (lastPersistedState.get(key) === value) {
log.trace(
{ key, last: lastPersistedState.get(key), next: value },
`Skipping persist for ${key} as value is unchanged`
);
return value;
}
log.trace({ key, last: lastPersistedState.get(key), next: value }, `Persisting state for ${key}`);
const url = getUrl('set_by_key', key);
const headers = getHeaders();
const res = await fetch(url, { method: 'POST', headers, body: value });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
const resultValue = await res.json();
lastPersistedState.set(key, resultValue);
return resultValue;
} catch (originalError) {
throw new StorageError({
key,
value,
projectId: $projectId.get(),
originalError,
});
} finally {
persistRefCount--;
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
export const reduxRememberDriver: Driver = { getItem, setItem };
export const clearStorage = async () => {
try {
persistRefCount++;
const url = getUrl('delete');
const headers = getHeaders();
const res = await fetch(url, { method: 'POST', headers });
if (!res.ok) {
throw new Error(`Response status: ${res.status}`);
}
} catch {
log.error('Failed to reset client state');
} finally {
persistRefCount--;
lastPersistedState.clear();
if (persistRefCount < 0) {
log.trace('Persist ref count is negative, resetting to 0');
persistRefCount = 0;
}
}
};
export const addStorageListeners = () => {
const onBeforeUnload = (e: BeforeUnloadEvent) => {
if (persistRefCount > 0) {
e.preventDefault();
}
};
window.addEventListener('beforeunload', onBeforeUnload);
return () => {
window.removeEventListener('beforeunload', onBeforeUnload);
};
};

View File

@@ -33,8 +33,9 @@ export class StorageError extends Error {
}
}
const log = logger('system');
export const errorHandler = (err: PersistError | RehydrateError) => {
const log = logger('system');
if (err instanceof PersistError) {
log.error({ error: serializeError(err) }, 'Problem persisting state');
} else if (err instanceof RehydrateError) {

View File

@@ -1,73 +0,0 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
const startAppListening = listenerMiddleware.startListening as AppStartListening;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
/**
* The RTK listener middleware is a lightweight alternative sagas/observables.
*
* Most side effect logic should live in a listener.
*/
// Image uploaded
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -1,6 +1,6 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
autoAddBoardIdChanged,

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';

View File

@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { truncate } from 'es-toolkit/compat';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state';

View File

@@ -1,5 +1,5 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { parseify } from 'common/util/serialize';
import { size } from 'es-toolkit/compat';
import { $templates } from 'features/nodes/store/nodesSlice';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');

View File

@@ -1,7 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import type { AppStartListening, RootState } from 'app/store/store';
import { omit } from 'es-toolkit/compat';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';

View File

@@ -1,5 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';

View File

@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { AppDispatch, AppStartListening, RootState } from 'app/store/store';
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import {

View File

@@ -1,4 +1,4 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppStartListening } from 'app/store/store';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';

View File

@@ -1,8 +1,8 @@
import { objectEquals } from '@observ33r/object-equals';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import type { AppStartListening } from 'app/store/store';
import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';

View File

@@ -1,159 +1,165 @@
import type { ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/toolkit';
import type { ThunkDispatch, TypedStartListening, UnknownAction } from '@reduxjs/toolkit';
import { addListener, combineReducers, configureStore, createAction, createListenerMiddleware } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/batchEnqueued';
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { deepClone } from 'common/util/deepClone';
import { keys, mergeWith, omit, pick } from 'es-toolkit/compat';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
import {
canvasSessionSlice,
canvasStagingAreaPersistConfig,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice';
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
import { changeBoardModalSliceConfig } from 'features/changeBoardModal/store/slice';
import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice';
import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice';
import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice';
import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsSliceConfig } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { gallerySliceConfig } from 'features/gallery/store/gallerySlice';
import { modelManagerSliceConfig } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesSliceConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibrarySliceConfig } from 'features/nodes/store/workflowLibrarySlice';
import { workflowSettingsSliceConfig } from 'features/nodes/store/workflowSettingsSlice';
import { upscaleSliceConfig } from 'features/parameters/store/upscaleSlice';
import { queueSliceConfig } from 'features/queue/store/queueSlice';
import { stylePresetSliceConfig } from 'features/stylePresets/store/stylePresetSlice';
import { configSliceConfig } from 'features/system/store/configSlice';
import { systemSliceConfig } from 'features/system/store/systemSlice';
import { uiSliceConfig } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import { REMEMBER_REHYDRATED, rememberEnhancer, rememberReducer } from 'redux-remember';
import undoable, { newHistory } from 'redux-undo';
import { serializeError } from 'serialize-error';
import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { reduxRememberDriver } from './enhancers/reduxRemember/driver';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { addArchivedOrDeletedBoardListener } from './middleware/listenerMiddleware/listeners/addArchivedOrDeletedBoardListener';
import { addImageUploadedFulfilledListener } from './middleware/listenerMiddleware/listeners/imageUploaded';
export const listenerMiddleware = createListenerMiddleware();
const log = logger('system');
const allReducers = {
[api.reducerPath]: api.reducer,
[gallerySlice.name]: gallerySlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
[uiSlice.name]: uiSlice.reducer,
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer,
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[queueSlice.name]: queueSlice.reducer,
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
[stylePresetSlice.name]: stylePresetSlice.reducer,
[paramsSlice.name]: paramsSlice.reducer,
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
[refImagesSlice.name]: refImagesSlice.reducer,
// When adding a slice, add the config to the SLICE_CONFIGS object below, then add the reducer to ALL_REDUCERS.
const SLICE_CONFIGS = {
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig,
[canvasSliceConfig.slice.reducerPath]: canvasSliceConfig,
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig,
[configSliceConfig.slice.reducerPath]: configSliceConfig,
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig,
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig,
[lorasSliceConfig.slice.reducerPath]: lorasSliceConfig,
[modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig,
[nodesSliceConfig.slice.reducerPath]: nodesSliceConfig,
[paramsSliceConfig.slice.reducerPath]: paramsSliceConfig,
[queueSliceConfig.slice.reducerPath]: queueSliceConfig,
[refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig,
[stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig,
[systemSliceConfig.slice.reducerPath]: systemSliceConfig,
[uiSliceConfig.slice.reducerPath]: uiSliceConfig,
[upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig,
[workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig,
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig,
};
const rootReducer = combineReducers(allReducers);
// TS makes it really hard to dynamically create this object :/ so it's just hardcoded here.
// Remember to wrap undoable reducers in `undoable()`!
const ALL_REDUCERS = {
[api.reducerPath]: api.reducer,
[canvasSessionSliceConfig.slice.reducerPath]: canvasSessionSliceConfig.slice.reducer,
[canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig.slice.reducer,
// Undoable!
[canvasSliceConfig.slice.reducerPath]: undoable(
canvasSliceConfig.slice.reducer,
canvasSliceConfig.undoableConfig?.reduxUndoOptions
),
[changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer,
[configSliceConfig.slice.reducerPath]: configSliceConfig.slice.reducer,
[dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer,
[gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer,
[lorasSliceConfig.slice.reducerPath]: lorasSliceConfig.slice.reducer,
[modelManagerSliceConfig.slice.reducerPath]: modelManagerSliceConfig.slice.reducer,
// Undoable!
[nodesSliceConfig.slice.reducerPath]: undoable(
nodesSliceConfig.slice.reducer,
nodesSliceConfig.undoableConfig?.reduxUndoOptions
),
[paramsSliceConfig.slice.reducerPath]: paramsSliceConfig.slice.reducer,
[queueSliceConfig.slice.reducerPath]: queueSliceConfig.slice.reducer,
[refImagesSliceConfig.slice.reducerPath]: refImagesSliceConfig.slice.reducer,
[stylePresetSliceConfig.slice.reducerPath]: stylePresetSliceConfig.slice.reducer,
[systemSliceConfig.slice.reducerPath]: systemSliceConfig.slice.reducer,
[uiSliceConfig.slice.reducerPath]: uiSliceConfig.slice.reducer,
[upscaleSliceConfig.slice.reducerPath]: upscaleSliceConfig.slice.reducer,
[workflowLibrarySliceConfig.slice.reducerPath]: workflowLibrarySliceConfig.slice.reducer,
[workflowSettingsSliceConfig.slice.reducerPath]: workflowSettingsSliceConfig.slice.reducer,
};
const rootReducer = combineReducers(ALL_REDUCERS);
const rememberedRootReducer = rememberReducer(rootReducer);
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export type PersistConfig<T = any> = {
/**
* The name of the slice.
*/
name: keyof typeof allReducers;
/**
* The initial state of the slice.
*/
initialState: T;
/**
* Migrate the state to the current version during rehydration.
* @param state The rehydrated state.
* @returns A correctly-shaped state.
*/
migrate: (state: unknown) => T;
/**
* Keys to omit from the persisted state.
*/
persistDenylist: (keyof T)[];
};
const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[galleryPersistConfig.name]: galleryPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
[uiPersistConfig.name]: uiPersistConfig,
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[canvasPersistConfig.name]: canvasPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
[paramsPersistConfig.name]: paramsPersistConfig,
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
[refImagesSlice.name]: refImagesPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
if (!persistConfig) {
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
if (!sliceConfig?.persistConfig) {
throw new Error(`No persist config for slice "${key}"`);
}
const { getInitialState, persistConfig, undoableConfig } = sliceConfig;
let state;
try {
const { initialState, migrate } = persistConfig;
const initialState = getInitialState();
const parsed = JSON.parse(data);
// strip out old keys
const stripped = pick(deepClone(parsed), keys(initialState));
// run (additive) migrations
const migrated = migrate(stripped);
/*
* Merge in initial state as default values, covering any missing keys. You might be tempted to use _.defaultsDeep,
* but that merges arrays by index and partial objects by key. Using an identity function as the customizer results
* in behaviour like defaultsDeep, but doesn't overwrite any values that are not undefined in the migrated state.
*/
const transformed = mergeWith(migrated, initialState, (objVal) => objVal);
const unPersistDenylisted = mergeWith(stripped, initialState, (objVal) => objVal);
// run (additive) migrations
const migrated = persistConfig.migrate(unPersistDenylisted);
log.debug(
{
persistedData: parsed,
rehydratedData: transformed,
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
persistedData: parsed as JsonObject,
rehydratedData: migrated as JsonObject,
diff: diff(data, migrated) as JsonObject,
},
`Rehydrated slice "${key}"`
);
state = transformed;
state = migrated;
} catch (err) {
log.warn(
{ error: serializeError(err as Error) },
`Error rehydrating slice "${key}", falling back to default initial state`
);
state = persistConfig.initialState;
state = getInitialState();
}
// If the slice is undoable, we need to wrap it in a new history - only nodes and canvas are undoable at the moment.
// TODO(psyche): make this automatic & remove the hard-coding for specific slices.
if (key === nodesSlice.name || key === canvasSlice.name) {
// Undoable slices must be wrapped in a history!
if (undoableConfig) {
return newHistory([], state, []);
} else {
return state;
@@ -161,43 +167,53 @@ const unserialize: UnserializeFunction = (data, key) => {
};
const serialize: SerializeFunction = (data, key) => {
const persistConfig = persistConfigs[key as keyof typeof persistConfigs];
if (!persistConfig) {
const sliceConfig = SLICE_CONFIGS[key as keyof typeof SLICE_CONFIGS];
if (!sliceConfig?.persistConfig) {
throw new Error(`No persist config for slice "${key}"`);
}
// Heuristic to determine if the slice is undoable - could just hardcode it in the persistConfig
const isUndoable = 'present' in data && 'past' in data && 'future' in data && '_latestUnfiltered' in data;
const result = omit(isUndoable ? data.present : data, persistConfig.persistDenylist);
const result = omit(
sliceConfig.undoableConfig ? data.present : data,
sliceConfig.persistConfig.persistDenylist ?? []
);
return JSON.stringify(result);
};
export const createStore = (uniqueStoreKey?: string, persist = true) =>
configureStore({
const PERSISTED_KEYS = Object.values(SLICE_CONFIGS)
.filter((sliceConfig) => !!sliceConfig.persistConfig)
.map((sliceConfig) => sliceConfig.slice.reducerPath);
export const createStore = (options?: { persist?: boolean; persistThrottle?: number; onRehydrated?: () => void }) => {
const store = configureStore({
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
// serializableCheck: false,
// immutableCheck: false,
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
})
.concat(api.middleware)
.concat(dynamicMiddlewares)
.concat(authToastMiddleware)
// .concat(getDebugLoggerMiddleware())
// .concat(getDebugLoggerMiddleware({ withDiff: true, withNextState: true }))
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
if (persist) {
_enhancers.push(
rememberEnhancer(idbKeyValDriver, keys(persistConfigs), {
persistDebounce: 300,
const enhancers = getDefaultEnhancers();
if (options?.persist) {
return enhancers.prepend(
rememberEnhancer(reduxRememberDriver, PERSISTED_KEYS, {
persistThrottle: options?.persistThrottle ?? 2000,
serialize,
unserialize,
prefix: uniqueStoreKey ? `${STORAGE_PREFIX}${uniqueStoreKey}-` : STORAGE_PREFIX,
prefix: '',
errorHandler,
})
);
} else {
return enhancers;
}
return _enhancers;
},
devTools: {
actionSanitizer,
@@ -212,9 +228,62 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
},
});
// Once-off listener to support waiting for rehydration before rendering the app
startAppListening({
actionCreator: createAction(REMEMBER_REHYDRATED),
effect: (action, { unsubscribe }) => {
unsubscribe();
options?.onRehydrated?.();
},
});
return store;
};
export type AppStore = ReturnType<typeof createStore>;
export type RootState = ReturnType<AppStore['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
export type AppGetState = ReturnType<typeof createStore>['getState'];
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
const startAppListening = listenerMiddleware.startListening as AppStartListening;
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);
// Models
addModelSelectedListener(startAppListening);
// app startup
addAppStartedListener(startAppListening);
addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);

View File

@@ -0,0 +1,46 @@
import type { Slice } from '@reduxjs/toolkit';
import type { UndoableOptions } from 'redux-undo';
import type { ZodType } from 'zod';
type StateFromSlice<T extends Slice> = T extends Slice<infer U> ? U : never;
export type SliceConfig<T extends Slice> = {
/**
* The redux slice (return of createSlice).
*/
slice: T;
/**
* The zod schema for the slice.
*/
schema: ZodType<StateFromSlice<T>>;
/**
* A function that returns the initial state of the slice.
*/
getInitialState: () => StateFromSlice<T>;
/**
* The optional persist configuration for this slice. If omitted, the slice will not be persisted.
*/
persistConfig?: {
/**
* Migrate the state to the current version during rehydration. This method should throw an error if the migration
* fails.
*
* @param state The rehydrated state.
* @returns A correctly-shaped state.
*/
migrate: (state: unknown) => StateFromSlice<T>;
/**
* Keys to omit from the persisted state.
*/
persistDenylist?: (keyof StateFromSlice<T>)[];
};
/**
* The optional undoable configuration for this slice. If omitted, the slice will not be undoable.
*/
undoableConfig?: {
/**
* The options to be passed into redux-undo.
*/
reduxUndoOptions: UndoableOptions<StateFromSlice<T>>;
};
};

View File

@@ -1,130 +1,299 @@
import type { FilterType } from 'features/controlLayers/store/filters';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { TabName } from 'features/ui/store/uiTypes';
import { zFilterType } from 'features/controlLayers/store/filters';
import { zParameterPrecision, zParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { zTabName } from 'features/ui/store/uiTypes';
import type { PartialDeep } from 'type-fest';
import z from 'zod';
/**
* A disable-able application feature
*/
export type AppFeature =
| 'faceRestore'
| 'upscaling'
| 'lightbox'
| 'modelManager'
| 'githubLink'
| 'discordLink'
| 'bugLink'
| 'aboutModal'
| 'localization'
| 'consoleLogging'
| 'dynamicPrompting'
| 'batches'
| 'syncModels'
| 'multiselect'
| 'pauseQueue'
| 'resumeQueue'
| 'invocationCache'
| 'modelCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll'
| 'chatGPT4oHigh'
| 'modelRelationships';
/**
* A disable-able Stable Diffusion feature
*/
export type SDFeature =
| 'controlNet'
| 'noise'
| 'perlinNoise'
| 'noiseThreshold'
| 'variation'
| 'symmetry'
| 'seamless'
| 'hires'
| 'lora'
| 'embedding'
| 'vae'
| 'hrf';
const zAppFeature = z.enum([
'faceRestore',
'upscaling',
'lightbox',
'modelManager',
'githubLink',
'discordLink',
'bugLink',
'aboutModal',
'localization',
'consoleLogging',
'dynamicPrompting',
'batches',
'syncModels',
'multiselect',
'pauseQueue',
'resumeQueue',
'invocationCache',
'modelCache',
'bulkDownload',
'starterModels',
'hfToken',
'retryQueueItem',
'cancelAndClearAll',
'chatGPT4oHigh',
'modelRelationships',
]);
export type AppFeature = z.infer<typeof zAppFeature>;
export type NumericalParameterConfig = {
initial: number;
sliderMin: number;
sliderMax: number;
numberInputMin: number;
numberInputMax: number;
fineStep: number;
coarseStep: number;
};
const zSDFeature = z.enum([
'controlNet',
'noise',
'perlinNoise',
'noiseThreshold',
'variation',
'symmetry',
'seamless',
'hires',
'lora',
'embedding',
'vae',
'hrf',
]);
export type SDFeature = z.infer<typeof zSDFeature>;
const zNumericalParameterConfig = z.object({
initial: z.number().default(512),
sliderMin: z.number().default(64),
sliderMax: z.number().default(1536),
numberInputMin: z.number().default(64),
numberInputMax: z.number().default(4096),
fineStep: z.number().default(8),
coarseStep: z.number().default(64),
});
/**
* Configuration options for the InvokeAI UI.
* Distinct from system settings which may be changed inside the app.
*/
export type AppConfig = {
export const zAppConfig = z.object({
/**
* Whether or not we should update image urls when image loading errors
*/
shouldUpdateImagesOnConnect: boolean;
shouldFetchMetadataFromApi: boolean;
shouldUpdateImagesOnConnect: z.boolean(),
shouldFetchMetadataFromApi: z.boolean(),
/**
* Sets a size limit for outputs on the upscaling tab. This is a maximum dimension, so the actual max number of pixels
* will be the square of this value.
*/
maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
allowPublishWorkflows: boolean;
allowPromptExpansion: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
nodesAllowlist: string[] | undefined;
nodesDenylist: string[] | undefined;
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;
shouldShowCredits: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];
disabledControlNetProcessors: FilterType[];
maxUpscaleDimension: z.number().optional(),
allowPrivateBoards: z.boolean(),
allowPrivateStylePresets: z.boolean(),
allowClientSideUpload: z.boolean(),
allowPublishWorkflows: z.boolean(),
allowPromptExpansion: z.boolean(),
disabledTabs: z.array(zTabName),
disabledFeatures: z.array(zAppFeature),
disabledSDFeatures: z.array(zSDFeature),
nodesAllowlist: z.array(z.string()).optional(),
nodesDenylist: z.array(z.string()).optional(),
metadataFetchDebounce: z.number().int().optional(),
workflowFetchDebounce: z.number().int().optional(),
isLocal: z.boolean().optional(),
shouldShowCredits: z.boolean().optional(),
sd: z.object({
defaultModel: z.string().optional(),
disabledControlNetModels: z.array(z.string()),
disabledControlNetProcessors: z.array(zFilterType),
// Core parameters
iterations: NumericalParameterConfig;
width: NumericalParameterConfig; // initial value comes from model
height: NumericalParameterConfig; // initial value comes from model
steps: NumericalParameterConfig;
guidance: NumericalParameterConfig;
cfgRescaleMultiplier: NumericalParameterConfig;
img2imgStrength: NumericalParameterConfig;
scheduler?: ParameterScheduler;
vaePrecision?: ParameterPrecision;
iterations: zNumericalParameterConfig,
width: zNumericalParameterConfig,
height: zNumericalParameterConfig,
steps: zNumericalParameterConfig,
guidance: zNumericalParameterConfig,
cfgRescaleMultiplier: zNumericalParameterConfig,
img2imgStrength: zNumericalParameterConfig,
scheduler: zParameterScheduler.optional(),
vaePrecision: zParameterPrecision.optional(),
// Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
scaledBoundingBoxHeight: NumericalParameterConfig; // initial value comes from model
scaledBoundingBoxWidth: NumericalParameterConfig; // initial value comes from model
canvasCoherenceStrength: NumericalParameterConfig;
canvasCoherenceEdgeSize: NumericalParameterConfig;
infillTileSize: NumericalParameterConfig;
infillPatchmatchDownscaleSize: NumericalParameterConfig;
boundingBoxHeight: zNumericalParameterConfig,
boundingBoxWidth: zNumericalParameterConfig,
scaledBoundingBoxHeight: zNumericalParameterConfig,
scaledBoundingBoxWidth: zNumericalParameterConfig,
canvasCoherenceStrength: zNumericalParameterConfig,
canvasCoherenceEdgeSize: zNumericalParameterConfig,
infillTileSize: zNumericalParameterConfig,
infillPatchmatchDownscaleSize: zNumericalParameterConfig,
// Misc advanced
clipSkip: NumericalParameterConfig; // slider and input max are ignored for this, because the values depend on the model
maskBlur: NumericalParameterConfig;
hrfStrength: NumericalParameterConfig;
dynamicPrompts: {
maxPrompts: NumericalParameterConfig;
};
ca: {
weight: NumericalParameterConfig;
};
};
flux: {
guidance: NumericalParameterConfig;
};
};
clipSkip: zNumericalParameterConfig, // slider and input max are ignored for this, because the values depend on the model
maskBlur: zNumericalParameterConfig,
hrfStrength: zNumericalParameterConfig,
dynamicPrompts: z.object({
maxPrompts: zNumericalParameterConfig,
}),
ca: z.object({
weight: zNumericalParameterConfig,
}),
}),
flux: z.object({
guidance: zNumericalParameterConfig,
}),
});
export type AppConfig = z.infer<typeof zAppConfig>;
export type PartialAppConfig = PartialDeep<AppConfig>;
export const getDefaultAppConfig = (): AppConfig => ({
isLocal: true,
shouldUpdateImagesOnConnect: false,
shouldFetchMetadataFromApi: false,
allowPrivateBoards: false,
allowPrivateStylePresets: false,
allowClientSideUpload: false,
allowPublishWorkflows: false,
allowPromptExpansion: false,
shouldShowCredits: false,
disabledTabs: [],
disabledFeatures: ['lightbox', 'faceRestore', 'batches'] satisfies AppFeature[],
disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'] satisfies SDFeature[],
sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: {
initial: 1,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 1,
},
width: zNumericalParameterConfig.parse({}), // initial value comes from model
height: zNumericalParameterConfig.parse({}), // initial value comes from model
boundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
boundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
scaledBoundingBoxWidth: zNumericalParameterConfig.parse({}), // initial value comes from model
scaledBoundingBoxHeight: zNumericalParameterConfig.parse({}), // initial value comes from model
scheduler: 'dpmpp_3m_k' as const,
vaePrecision: 'fp32' as const,
steps: {
initial: 30,
sliderMin: 1,
sliderMax: 100,
numberInputMin: 1,
numberInputMax: 500,
fineStep: 1,
coarseStep: 1,
},
guidance: {
initial: 7,
sliderMin: 1,
sliderMax: 20,
numberInputMin: 1,
numberInputMax: 200,
fineStep: 0.1,
coarseStep: 0.5,
},
img2imgStrength: {
initial: 0.7,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceStrength: {
initial: 0.3,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
hrfStrength: {
initial: 0.45,
sliderMin: 0,
sliderMax: 1,
numberInputMin: 0,
numberInputMax: 1,
fineStep: 0.01,
coarseStep: 0.05,
},
canvasCoherenceEdgeSize: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 1024,
fineStep: 8,
coarseStep: 16,
},
cfgRescaleMultiplier: {
initial: 0,
sliderMin: 0,
sliderMax: 0.99,
numberInputMin: 0,
numberInputMax: 0.99,
fineStep: 0.05,
coarseStep: 0.1,
},
clipSkip: {
initial: 0,
sliderMin: 0,
sliderMax: 12, // determined by model selection, unused in practice
numberInputMin: 0,
numberInputMax: 12, // determined by model selection, unused in practice
fineStep: 1,
coarseStep: 1,
},
infillPatchmatchDownscaleSize: {
initial: 1,
sliderMin: 1,
sliderMax: 10,
numberInputMin: 1,
numberInputMax: 10,
fineStep: 1,
coarseStep: 1,
},
infillTileSize: {
initial: 32,
sliderMin: 16,
sliderMax: 64,
numberInputMin: 16,
numberInputMax: 256,
fineStep: 1,
coarseStep: 1,
},
maskBlur: {
initial: 16,
sliderMin: 0,
sliderMax: 128,
numberInputMin: 0,
numberInputMax: 512,
fineStep: 1,
coarseStep: 1,
},
ca: {
weight: {
initial: 1,
sliderMin: 0,
sliderMax: 2,
numberInputMin: -1,
numberInputMax: 2,
fineStep: 0.01,
coarseStep: 0.05,
},
},
dynamicPrompts: {
maxPrompts: {
initial: 100,
sliderMin: 1,
sliderMax: 1000,
numberInputMin: 1,
numberInputMax: 10000,
fineStep: 1,
coarseStep: 10,
},
},
},
flux: {
guidance: {
initial: 4,
sliderMin: 2,
sliderMax: 6,
numberInputMin: 1,
numberInputMax: 20,
fineStep: 0.1,
coarseStep: 0.5,
},
},
});

View File

@@ -1,9 +1,9 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { canvasReset } from 'features/controlLayers/store/actions';
import { inpaintMaskAdded } from 'features/controlLayers/store/canvasSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { allEntitiesDeleted, inpaintMaskAdded } from 'features/controlLayers/store/canvasSlice';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold } from 'react-icons/pi';
@@ -11,9 +11,10 @@ import { PiArrowsCounterClockwiseBold } from 'react-icons/pi';
export const SessionMenuItems = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const tab = useAppSelector(selectActiveTab);
const resetCanvasLayers = useCallback(() => {
dispatch(canvasReset());
dispatch(allEntitiesDeleted());
dispatch(inpaintMaskAdded({ isSelected: true, isBookmarked: true }));
$canvasManager.get()?.stage.fitBboxToStage();
}, [dispatch]);
@@ -22,12 +23,16 @@ export const SessionMenuItems = memo(() => {
}, [dispatch]);
return (
<>
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetCanvasLayers}>
{t('controlLayers.resetCanvasLayers')}
</MenuItem>
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetGenerationSettings}>
{t('controlLayers.resetGenerationSettings')}
</MenuItem>
{tab === 'canvas' && (
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetCanvasLayers}>
{t('controlLayers.resetCanvasLayers')}
</MenuItem>
)}
{(tab === 'canvas' || tab === 'generate') && (
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetGenerationSettings}>
{t('controlLayers.resetGenerationSettings')}
</MenuItem>
)}
</>
);
});

View File

@@ -1,11 +0,0 @@
import { clearIdbKeyValStore } from 'app/store/enhancers/reduxRemember/driver';
import { useCallback } from 'react';
export const useClearStorage = () => {
const clearStorage = useCallback(() => {
clearIdbKeyValStore();
localStorage.clear();
}, []);
return clearStorage;
};

View File

@@ -1,6 +0,0 @@
import type { ChangeBoardModalState } from './types';
export const initialState: ChangeBoardModalState = {
isModalOpen: false,
image_names: [],
};

View File

@@ -1,12 +1,20 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import z from 'zod';
import { initialState } from './initialState';
const zChangeBoardModalState = z.object({
isModalOpen: z.boolean().default(false),
image_names: z.array(z.string()).default(() => []),
});
type ChangeBoardModalState = z.infer<typeof zChangeBoardModalState>;
export const changeBoardModalSlice = createSlice({
const getInitialState = (): ChangeBoardModalState => zChangeBoardModalState.parse({});
const slice = createSlice({
name: 'changeBoardModal',
initialState,
initialState: getInitialState(),
reducers: {
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isModalOpen = action.payload;
@@ -21,6 +29,12 @@ export const changeBoardModalSlice = createSlice({
},
});
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = changeBoardModalSlice.actions;
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = slice.actions;
export const selectChangeBoardModalSlice = (state: RootState) => state.changeBoardModal;
export const changeBoardModalSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zChangeBoardModalState,
getInitialState,
};

View File

@@ -1,4 +0,0 @@
export type ChangeBoardModalState = {
isModalOpen: boolean;
image_names: string[];
};

View File

@@ -1,15 +1,20 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import { bboxSizeOptimized, bboxSizeRecalled } from 'features/controlLayers/store/canvasSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { sizeOptimized, sizeRecalled } from 'features/controlLayers/store/paramsSlice';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { $isConnected } from 'services/events/stores';
@@ -29,7 +34,10 @@ export const RefImageImage = memo(
dndTargetData,
}: Props<T>) => {
const { t } = useTranslation();
const store = useAppStore();
const isConnected = useStore($isConnected);
const tab = useAppSelector(selectActiveTab);
const isStaging = useCanvasIsStaging();
const { currentData: imageDTO, isError } = useGetImageDTOQuery(image?.image_name ?? skipToken);
const handleResetControlImage = useCallback(() => {
onChangeImage(null);
@@ -48,6 +56,20 @@ export const RefImageImage = memo(
[onChangeImage]
);
const recallSizeAndOptimize = useCallback(() => {
if (!imageDTO || (tab === 'canvas' && isStaging)) {
return;
}
const { width, height } = imageDTO;
if (tab === 'canvas') {
store.dispatch(bboxSizeRecalled({ width, height }));
store.dispatch(bboxSizeOptimized());
} else if (tab === 'generate') {
store.dispatch(sizeRecalled({ width, height }));
store.dispatch(sizeOptimized());
}
}, [imageDTO, isStaging, store, tab]);
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
{!imageDTO && (
@@ -69,6 +91,14 @@ export const RefImageImage = memo(
tooltip={t('common.reset')}
/>
</Flex>
<Flex position="absolute" flexDir="column" bottom={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={recallSizeAndOptimize}
icon={<PiRulerBold size={16} />}
tooltip={t('parameters.useSize')}
isDisabled={!imageDTO || (tab === 'canvas' && isStaging)}
/>
</Flex>
</>
)}
<DndDropTarget dndTarget={dndTarget} dndTargetData={dndTargetData} label={t('gallery.drop')} />

View File

@@ -63,6 +63,7 @@ RefImageList.displayName = 'RefImageList';
const dndTargetData = addGlobalReferenceImageDndTarget.getData();
const MaxRefImages = memo(() => {
const { t } = useTranslation();
return (
<Button
position="relative"
@@ -75,7 +76,7 @@ const MaxRefImages = memo(() => {
borderRadius="base"
isDisabled
>
Max Ref Images
{t('controlLayers.maxRefImages')}
</Button>
);
});
@@ -83,6 +84,7 @@ MaxRefImages.displayName = 'MaxRefImages';
const AddRefImageDropTargetAndButton = memo(() => {
const { dispatch, getState } = useAppStore();
const { t } = useTranslation();
const tab = useAppSelector(selectActiveTab);
const uploadOptions = useMemo(
@@ -114,7 +116,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
leftIcon={<PiUploadBold />}
{...uploadApi.getUploadButtonProps()}
>
Reference Image
{t('controlLayers.referenceImage')}
<input {...uploadApi.getUploadInputProps()} />
<DndDropTarget label="Drop" dndTarget={addGlobalReferenceImageDndTarget} dndTargetData={dndTargetData} />
</Button>

View File

@@ -1,6 +1,8 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiResizeBold } from 'react-icons/pi';
@@ -9,9 +11,23 @@ export const CanvasToolbarFitBboxToLayersButton = memo(() => {
const { t } = useTranslation();
const canvasManager = useCanvasManager();
const isBusy = useCanvasIsBusy();
const isCanvasFocused = useIsRegionFocused('canvas');
const onClick = useCallback(() => {
canvasManager.tool.tools.bbox.fitToLayers();
}, [canvasManager.tool.tools.bbox]);
canvasManager.stage.fitLayersToStage();
}, [canvasManager.tool.tools.bbox, canvasManager.stage]);
useRegisteredHotkeys({
id: 'fitBboxToLayers',
category: 'canvas',
callback: () => {
canvasManager.tool.tools.bbox.fitToLayers();
canvasManager.stage.fitLayersToStage();
},
options: { enabled: isCanvasFocused && !isBusy, preventDefault: true },
dependencies: [isCanvasFocused, isBusy],
});
return (
<IconButton

View File

@@ -1,7 +1,7 @@
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Selector } from '@reduxjs/toolkit';
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { AppStore, RootState } from 'app/store/store';
import { addAppListener } from 'app/store/store';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -319,6 +319,14 @@ export class CanvasStateApiModule extends CanvasModuleBase {
getPositionGridSize = (): number => {
const snapToGrid = this.getSettings().snapToGrid;
if (!snapToGrid) {
const overrideSnap = this.$ctrlKey.get() || this.$metaKey.get();
if (overrideSnap) {
const useFine = this.$shiftKey.get();
if (useFine) {
return 8;
}
return 64;
}
return 1;
}
const useFine = this.$ctrlKey.get() || this.$metaKey.get();

View File

@@ -1,6 +1,7 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { zRgbaColor } from 'features/controlLayers/store/types';
import { z } from 'zod';
@@ -11,32 +12,32 @@ const zCanvasSettingsState = z.object({
/**
* Whether to show HUD (Heads-Up Display) on the canvas.
*/
showHUD: z.boolean().default(true),
showHUD: z.boolean(),
/**
* Whether to clip lines and shapes to the generation bounding box. If disabled, lines and shapes will be clipped to
* the canvas bounds.
*/
clipToBbox: z.boolean().default(false),
clipToBbox: z.boolean(),
/**
* Whether to show a dynamic grid on the canvas. If disabled, a checkerboard pattern will be shown instead.
*/
dynamicGrid: z.boolean().default(false),
dynamicGrid: z.boolean(),
/**
* Whether to invert the scroll direction when adjusting the brush or eraser width with the scroll wheel.
*/
invertScrollForToolWidth: z.boolean().default(false),
invertScrollForToolWidth: z.boolean(),
/**
* The width of the brush tool.
*/
brushWidth: z.int().gt(0).default(50),
brushWidth: z.int().gt(0),
/**
* The width of the eraser tool.
*/
eraserWidth: z.int().gt(0).default(50),
eraserWidth: z.int().gt(0),
/**
* The color to use when drawing lines or filling shapes.
*/
color: zRgbaColor.default({ r: 31, g: 160, b: 224, a: 1 }), // invokeBlue.500
color: zRgbaColor,
/**
* Whether to composite inpainted/outpainted regions back onto the source image when saving canvas generations.
*
@@ -44,57 +45,77 @@ const zCanvasSettingsState = z.object({
*
* When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited.
*/
outputOnlyMaskedRegions: z.boolean().default(true),
outputOnlyMaskedRegions: z.boolean(),
/**
* Whether to automatically process the operations like filtering and auto-masking.
*/
autoProcess: z.boolean().default(true),
autoProcess: z.boolean(),
/**
* The snap-to-grid setting for the canvas.
*/
snapToGrid: z.boolean().default(true),
snapToGrid: z.boolean(),
/**
* Whether to show progress on the canvas when generating images.
*/
showProgressOnCanvas: z.boolean().default(true),
showProgressOnCanvas: z.boolean(),
/**
* Whether to show the bounding box overlay on the canvas.
*/
bboxOverlay: z.boolean().default(false),
bboxOverlay: z.boolean(),
/**
* Whether to preserve the masked region instead of inpainting it.
*/
preserveMask: z.boolean().default(false),
preserveMask: z.boolean(),
/**
* Whether to show only raster layers while staging.
*/
isolatedStagingPreview: z.boolean().default(true),
isolatedStagingPreview: z.boolean(),
/**
* Whether to show only the selected layer while filtering, transforming, or doing other operations.
*/
isolatedLayerPreview: z.boolean().default(true),
isolatedLayerPreview: z.boolean(),
/**
* Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
*/
pressureSensitivity: z.boolean().default(true),
pressureSensitivity: z.boolean(),
/**
* Whether to show the rule of thirds composition guide overlay on the canvas.
*/
ruleOfThirds: z.boolean().default(false),
ruleOfThirds: z.boolean(),
/**
* Whether to save all staging images to the gallery instead of keeping them as intermediate images.
*/
saveAllImagesToGallery: z.boolean().default(false),
saveAllImagesToGallery: z.boolean(),
/**
* The auto-switch mode for the canvas staging area.
*/
stagingAreaAutoSwitch: zAutoSwitchMode.default('switch_on_start'),
stagingAreaAutoSwitch: zAutoSwitchMode,
});
type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>;
const getInitialState = () => zCanvasSettingsState.parse({});
const getInitialState = (): CanvasSettingsState => ({
showHUD: true,
clipToBbox: false,
dynamicGrid: false,
invertScrollForToolWidth: false,
brushWidth: 50,
eraserWidth: 50,
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
outputOnlyMaskedRegions: true,
autoProcess: true,
snapToGrid: true,
showProgressOnCanvas: true,
bboxOverlay: false,
preserveMask: false,
isolatedStagingPreview: true,
isolatedLayerPreview: true,
pressureSensitivity: true,
ruleOfThirds: false,
saveAllImagesToGallery: false,
stagingAreaAutoSwitch: 'switch_on_start',
});
export const canvasSettingsSlice = createSlice({
const slice = createSlice({
name: 'canvasSettings',
initialState: getInitialState(),
reducers: {
@@ -184,18 +205,15 @@ export const {
settingsRuleOfThirdsToggled,
settingsSaveAllImagesToGalleryToggled,
settingsStagingAreaAutoSwitchChanged,
} = canvasSettingsSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const canvasSettingsPersistConfig: PersistConfig<CanvasSettingsState> = {
name: canvasSettingsSlice.name,
initialState: getInitialState(),
migrate,
persistDenylist: [],
export const canvasSettingsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zCanvasSettingsState,
getInitialState,
persistConfig: {
migrate: (state) => zCanvasSettingsState.parse(state),
},
};
export const selectCanvasSettingsSlice = (s: RootState) => s.canvasSettings;

View File

@@ -1,6 +1,6 @@
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { moveOneToEnd, moveOneToStart, moveToEnd, moveToStart } from 'common/util/arrayUtils';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
@@ -80,6 +80,7 @@ import {
isFLUXReduxConfig,
isImagenAspectRatioID,
isIPAdapterConfig,
zCanvasState,
} from './types';
import {
converters,
@@ -95,7 +96,7 @@ import {
initialT2IAdapter,
} from './util';
export const canvasSlice = createSlice({
const slice = createSlice({
name: 'canvas',
initialState: getInitialCanvasState(),
reducers: {
@@ -1090,6 +1091,15 @@ export const canvasSlice = createSlice({
syncScaledSize(state);
},
bboxSizeRecalled: (state, action: PayloadAction<{ width: number; height: number }>) => {
const { width, height } = action.payload;
const gridSize = getGridSize(state.bbox.modelBase);
state.bbox.rect.width = Math.max(roundDownToMultiple(width, gridSize), 64);
state.bbox.rect.height = Math.max(roundDownToMultiple(height, gridSize), 64);
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.id = 'Free';
state.bbox.aspectRatio.isLocked = true;
},
bboxAspectRatioLockToggled: (state) => {
state.bbox.aspectRatio.isLocked = !state.bbox.aspectRatio.isLocked;
syncScaledSize(state);
@@ -1618,6 +1628,7 @@ export const {
entityArrangedToBack,
entityOpacityChanged,
entitiesReordered,
allEntitiesDeleted,
allEntitiesOfTypeIsHiddenToggled,
allNonRasterLayersIsHiddenToggled,
// bbox
@@ -1625,6 +1636,7 @@ export const {
bboxScaledWidthChanged,
bboxScaledHeightChanged,
bboxScaleMethodChanged,
bboxSizeRecalled,
bboxWidthChanged,
bboxHeightChanged,
bboxAspectRatioLockToggled,
@@ -1675,19 +1687,7 @@ export const {
inpaintMaskDenoiseLimitChanged,
inpaintMaskDenoiseLimitDeleted,
// inpaintMaskRecalled,
} = canvasSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const canvasPersistConfig: PersistConfig<CanvasState> = {
name: canvasSlice.name,
initialState: getInitialCanvasState(),
migrate,
persistDenylist: [],
};
} = slice.actions;
const syncScaledSize = (state: CanvasState) => {
if (API_BASE_MODELS.includes(state.bbox.modelBase)) {
@@ -1710,14 +1710,14 @@ const syncScaledSize = (state: CanvasState) => {
let filter = true;
export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> = {
limit: 64,
undoType: canvasUndo.type,
redoType: canvasRedo.type,
clearHistoryType: canvasClearHistory.type,
filter: (action, _state, _history) => {
// Ignore all actions from other slices
if (!action.type.startsWith(canvasSlice.name)) {
if (!action.type.startsWith(slice.name)) {
return false;
}
// Throttle rapid actions of the same type
@@ -1728,6 +1728,18 @@ export const canvasUndoableConfig: UndoableOptions<CanvasState, UnknownAction> =
// debug: import.meta.env.MODE === 'development',
};
export const canvasSliceConfig: SliceConfig<typeof slice> = {
slice,
getInitialState: getInitialCanvasState,
schema: zCanvasState,
persistConfig: {
migrate: (state) => zCanvasState.parse(state),
},
undoableConfig: {
reduxUndoOptions: canvasUndoableConfig,
},
};
const doNotGroupMatcher = isAnyOf(entityBrushLineAdded, entityEraserLineAdded, entityRectAdded);
// Store rapid actions of the same type at most once every x time.

View File

@@ -1,27 +1,29 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
import { queueApi } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
import z from 'zod';
type CanvasStagingAreaState = {
_version: 1;
canvasSessionId: string;
canvasDiscardedQueueItems: number[];
};
const zCanvasStagingAreaState = z.object({
_version: z.literal(1),
canvasSessionId: z.string(),
canvasDiscardedQueueItems: z.array(z.number().int()),
});
type CanvasStagingAreaState = z.infer<typeof zCanvasStagingAreaState>;
const INITIAL_STATE: CanvasStagingAreaState = {
const getInitialState = (): CanvasStagingAreaState => ({
_version: 1,
canvasSessionId: getPrefixedId('canvas'),
canvasDiscardedQueueItems: [],
};
});
const getInitialState = (): CanvasStagingAreaState => deepClone(INITIAL_STATE);
export const canvasSessionSlice = createSlice({
const slice = createSlice({
name: 'canvasSession',
initialState: getInitialState(),
reducers: {
@@ -48,26 +50,26 @@ export const canvasSessionSlice = createSlice({
},
});
export const { canvasSessionReset, canvasQueueItemDiscarded } = canvasSessionSlice.actions;
export const { canvasSessionReset, canvasQueueItemDiscarded } = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
export const canvasSessionSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zCanvasStagingAreaState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
return state;
return zCanvasStagingAreaState.parse(state);
},
},
};
export const canvasStagingAreaPersistConfig: PersistConfig<CanvasStagingAreaState> = {
name: canvasSessionSlice.name,
initialState: getInitialState(),
migrate,
persistDenylist: [],
};
export const selectCanvasSessionSlice = (s: RootState) => s[canvasSessionSlice.name];
export const selectCanvasSessionSlice = (s: RootState) => s[slice.name];
export const selectCanvasSessionId = createSelector(selectCanvasSessionSlice, ({ canvasSessionId }) => canvasSessionId);
const selectDiscardedItems = createSelector(

View File

@@ -166,7 +166,7 @@ const _zFilterConfig = z.discriminatedUnion('type', [
]);
export type FilterConfig = z.infer<typeof _zFilterConfig>;
const zFilterType = z.enum([
export const zFilterType = z.enum([
'adjust_image',
'canny_edge_detection',
'color_map',

View File

@@ -1,30 +1,32 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { LoRA } from 'features/controlLayers/store/types';
import { type LoRA, zLoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import z from 'zod';
type LoRAsState = {
loras: LoRA[];
};
const zLoRAsState = z.object({
loras: z.array(zLoRA),
});
type LoRAsState = z.infer<typeof zLoRAsState>;
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};
const initialState: LoRAsState = {
const getInitialState = (): LoRAsState => ({
loras: [],
};
});
const selectLoRA = (state: LoRAsState, id: string) => state.loras.find((lora) => lora.id === id);
export const lorasSlice = createSlice({
const slice = createSlice({
name: 'loras',
initialState,
initialState: getInitialState(),
reducers: {
loraAdded: {
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
@@ -66,24 +68,21 @@ export const lorasSlice = createSlice({
extraReducers(builder) {
builder.addCase(paramsReset, () => {
// When a new session is requested, clear all LoRAs
return deepClone(initialState);
return getInitialState();
});
},
});
export const { loraAdded, loraRecalled, loraDeleted, loraWeightChanged, loraIsEnabledChanged, loraAllDeleted } =
lorasSlice.actions;
slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const lorasPersistConfig: PersistConfig<LoRAsState> = {
name: lorasSlice.name,
initialState,
migrate,
persistDenylist: [],
export const lorasSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zLoRAsState,
getInitialState,
persistConfig: {
migrate: (state) => zLoRAsState.parse(state),
},
};
export const selectLoRAsSlice = (state: RootState) => state.loras;

View File

@@ -1,6 +1,7 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import { clamp } from 'es-toolkit/compat';
@@ -15,6 +16,7 @@ import {
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isImagenAspectRatioID,
zParamsState,
} from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
@@ -40,7 +42,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
export const paramsSlice = createSlice({
const slice = createSlice({
name: 'params',
initialState: getInitialParamsState(),
reducers: {
@@ -92,7 +94,12 @@ export const paramsSlice = createSlice({
state,
action: PayloadAction<{ model: ParameterModel | null; previousModel?: ParameterModel | null }>
) => {
const { model, previousModel } = action.payload;
const { previousModel } = action.payload;
const result = zParamsState.shape.model.safeParse(action.payload.model);
if (!result.success) {
return;
}
const model = result.data;
state.model = model;
// If the model base changes (e.g. SD1.5 -> SDXL), we need to change a few things
@@ -111,25 +118,53 @@ export const paramsSlice = createSlice({
},
vaeSelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
// null is a valid VAE!
state.vae = action.payload;
const result = zParamsState.shape.vae.safeParse(action.payload);
if (!result.success) {
return;
}
state.vae = result.data;
},
fluxVAESelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
state.fluxVAE = action.payload;
const result = zParamsState.shape.fluxVAE.safeParse(action.payload);
if (!result.success) {
return;
}
state.fluxVAE = result.data;
},
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload;
const result = zParamsState.shape.t5EncoderModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.t5EncoderModel = result.data;
},
controlLoRAModelSelected: (state, action: PayloadAction<ParameterControlLoRAModel | null>) => {
state.controlLora = action.payload;
const result = zParamsState.shape.controlLora.safeParse(action.payload);
if (!result.success) {
return;
}
state.controlLora = result.data;
},
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload;
const result = zParamsState.shape.clipEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipEmbedModel = result.data;
},
clipLEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPLEmbedModel | null>) => {
state.clipLEmbedModel = action.payload;
const result = zParamsState.shape.clipLEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipLEmbedModel = result.data;
},
clipGEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPGEmbedModel | null>) => {
state.clipGEmbedModel = action.payload;
const result = zParamsState.shape.clipGEmbedModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.clipGEmbedModel = result.data;
},
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
state.vaePrecision = action.payload;
@@ -156,7 +191,11 @@ export const paramsSlice = createSlice({
state.shouldConcatPrompts = action.payload;
},
refinerModelChanged: (state, action: PayloadAction<ParameterSDXLRefinerModel | null>) => {
state.refinerModel = action.payload;
const result = zParamsState.shape.refinerModel.safeParse(action.payload);
if (!result.success) {
return;
}
state.refinerModel = result.data;
},
setRefinerSteps: (state, action: PayloadAction<number>) => {
state.refinerSteps = action.payload;
@@ -202,6 +241,15 @@ export const paramsSlice = createSlice({
},
//#region Dimensions
sizeRecalled: (state, action: PayloadAction<{ width: number; height: number }>) => {
const { width, height } = action.payload;
const gridSize = getGridSize(state.model?.base);
state.dimensions.rect.width = Math.max(roundDownToMultiple(width, gridSize), 64);
state.dimensions.rect.height = Math.max(roundDownToMultiple(height, gridSize), 64);
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.id = 'Free';
state.dimensions.aspectRatio.isLocked = true;
},
widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
const { width, updateAspectRatio, clamp } = action.payload;
const gridSize = getGridSize(state.model?.base);
@@ -330,14 +378,16 @@ export const paramsSlice = createSlice({
const resetState = (state: ParamsState): ParamsState => {
// When a new session is requested, we need to keep the current model selections, plus dependent state
// like VAE precision. Everything else gets reset to default.
const oldState = deepClone(state);
const newState = getInitialParamsState();
newState.model = state.model;
newState.vae = state.vae;
newState.fluxVAE = state.fluxVAE;
newState.vaePrecision = state.vaePrecision;
newState.t5EncoderModel = state.t5EncoderModel;
newState.clipEmbedModel = state.clipEmbedModel;
newState.refinerModel = state.refinerModel;
newState.dimensions = oldState.dimensions;
newState.model = oldState.model;
newState.vae = oldState.vae;
newState.fluxVAE = oldState.fluxVAE;
newState.vaePrecision = oldState.vaePrecision;
newState.t5EncoderModel = oldState.t5EncoderModel;
newState.clipEmbedModel = oldState.clipEmbedModel;
newState.refinerModel = oldState.refinerModel;
return newState;
};
@@ -388,6 +438,7 @@ export const {
modelChanged,
// Dimensions
sizeRecalled,
widthChanged,
heightChanged,
aspectRatioLockToggled,
@@ -397,18 +448,15 @@ export const {
syncedToOptimalDimension,
paramsReset,
} = paramsSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const paramsPersistConfig: PersistConfig<ParamsState> = {
name: paramsSlice.name,
initialState: getInitialParamsState(),
migrate,
persistDenylist: [],
export const paramsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zParamsState,
getInitialState: getInitialParamsState,
persistConfig: {
migrate: (state) => zParamsState.parse(state),
},
};
export const selectParamsSlice = (state: RootState) => state.params;

View File

@@ -2,7 +2,8 @@ import { objectEquals } from '@observ33r/object-equals';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { clamp } from 'es-toolkit/compat';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
@@ -18,7 +19,7 @@ import { assert } from 'tsafe';
import type { PartialDeep } from 'type-fest';
import type { CLIPVisionModelV2, IPMethodV2, RefImageState } from './types';
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig } from './types';
import { getInitialRefImagesState, isFLUXReduxConfig, isIPAdapterConfig, zRefImagesState } from './types';
import {
getReferenceImageState,
imageDTOToImageWithDims,
@@ -36,7 +37,7 @@ type PayloadActionWithId<T = void> = T extends void
} & T
>;
export const refImagesSlice = createSlice({
const slice = createSlice({
name: 'refImages',
initialState: getInitialRefImagesState(),
reducers: {
@@ -263,18 +264,16 @@ export const {
refImageFLUXReduxImageInfluenceChanged,
refImageIsEnabledToggled,
refImagesRecalled,
} = refImagesSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
return state;
};
export const refImagesPersistConfig: PersistConfig<RefImagesState> = {
name: refImagesSlice.name,
initialState: getInitialRefImagesState(),
migrate,
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
export const refImagesSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zRefImagesState,
getInitialState: getInitialRefImagesState,
persistConfig: {
migrate: (state) => zRefImagesState.parse(state),
persistDenylist: ['selectedEntityId', 'isPanelOpen'],
},
};
export const selectRefImagesSlice = (state: RootState) => state.refImages;

View File

@@ -1,9 +1,6 @@
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import type { ProgressImage } from 'features/nodes/types/common';
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import {
zParameterCanvasCoherenceMode,
zParameterCFGRescaleMultiplier,
@@ -29,33 +26,17 @@ import {
zParameterT5EncoderModel,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import type { JsonObject } from 'type-fest';
import { z } from 'zod';
const zId = z.string().min(1);
const zName = z.string().min(1).nullable();
const zServerValidatedModelIdentifierField = zModelIdentifierField.refine(async (modelIdentifier) => {
try {
await fetchModelConfigByIdentifier(modelIdentifier);
return true;
} catch {
return false;
}
export const zImageWithDims = z.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
});
const zImageWithDims = z
.object({
image_name: z.string(),
width: z.number().int().positive(),
height: z.number().int().positive(),
})
.refine(async (v) => {
const { image_name } = v;
const imageDTO = await getImageDTOSafe(image_name, { forceRefetch: true });
return imageDTO !== null;
});
export type ImageWithDims = z.infer<typeof zImageWithDims>;
const zImageWithDimsDataURL = z.object({
@@ -253,7 +234,7 @@ export type CanvasObjectState = z.infer<typeof zCanvasObjectState>;
const zIPAdapterConfig = z.object({
type: z.literal('ip_adapter'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
@@ -268,7 +249,7 @@ export type FLUXReduxImageInfluence = z.infer<typeof zFLUXReduxImageInfluence>;
const zFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
imageInfluence: zFLUXReduxImageInfluence.default('highest'),
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
@@ -281,14 +262,14 @@ const zChatGPT4oReferenceImageConfig = z.object({
* But we use a model drop down to switch between different ref image types, so there needs to be a model here else
* there will be no way to switch between ref image types.
*/
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;
const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
});
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;
@@ -360,7 +341,7 @@ export type CanvasInpaintMaskState = z.infer<typeof zCanvasInpaintMaskState>;
const zControlNetConfig = z.object({
type: z.literal('controlnet'),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
controlMode: zControlModeV2,
@@ -369,7 +350,7 @@ export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
const zT2IAdapterConfig = z.object({
type: z.literal('t2i_adapter'),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
});
@@ -378,7 +359,7 @@ export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
const zControlLoRAConfig = z.object({
type: z.literal('control_lora'),
weight: z.number().gte(-1).lte(2),
model: zServerValidatedModelIdentifierField.nullable(),
model: zModelIdentifierField.nullable(),
});
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
@@ -424,14 +405,13 @@ export const zCanvasEntityIdentifer = z.object({
});
export type CanvasEntityIdentifier<T extends CanvasEntityType = CanvasEntityType> = { id: string; type: T };
export type LoRA = {
id: string;
isEnabled: boolean;
model: ParameterLoRAModel;
weight: number;
};
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
export const zLoRA = z.object({
id: z.string(),
isEnabled: z.boolean(),
model: zModelIdentifierField,
weight: z.number().gte(-1).lte(2),
});
export type LoRA = z.infer<typeof zLoRA>;
export const zAspectRatioID = z.enum(['Free', '21:9', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16', '9:21']);
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
@@ -522,62 +502,108 @@ const zDimensionsState = z.object({
aspectRatio: zAspectRatioConfig,
});
const zParamsState = z.object({
maskBlur: z.number().default(16),
maskBlurMethod: zParameterMaskBlurMethod.default('box'),
canvasCoherenceMode: zParameterCanvasCoherenceMode.default('Gaussian Blur'),
canvasCoherenceMinDenoise: zParameterStrength.default(0),
canvasCoherenceEdgeSize: z.number().default(16),
infillMethod: z.string().default('lama'),
infillTileSize: z.number().default(32),
infillPatchmatchDownscaleSize: z.number().default(1),
infillColorValue: zRgbaColor.default({ r: 0, g: 0, b: 0, a: 1 }),
cfgScale: zParameterCFGScale.default(7.5),
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier.default(0),
guidance: zParameterGuidance.default(4),
img2imgStrength: zParameterStrength.default(0.75),
optimizedDenoisingEnabled: z.boolean().default(true),
iterations: z.number().default(1),
scheduler: zParameterScheduler.default('dpmpp_3m_k'),
upscaleScheduler: zParameterScheduler.default('kdpm_2'),
upscaleCfgScale: zParameterCFGScale.default(2),
seed: zParameterSeed.default(0),
shouldRandomizeSeed: z.boolean().default(true),
steps: zParameterSteps.default(30),
model: zParameterModel.nullable().default(null),
vae: zParameterVAEModel.nullable().default(null),
vaePrecision: zParameterPrecision.default('fp32'),
fluxVAE: zParameterVAEModel.nullable().default(null),
seamlessXAxis: z.boolean().default(false),
seamlessYAxis: z.boolean().default(false),
clipSkip: z.number().default(0),
shouldUseCpuNoise: z.boolean().default(true),
positivePrompt: zParameterPositivePrompt.default(''),
// Negative prompt may be disabled, in which case it will be null
negativePrompt: zParameterNegativePrompt.default(null),
positivePrompt2: zParameterPositiveStylePromptSDXL.default(''),
negativePrompt2: zParameterNegativeStylePromptSDXL.default(''),
shouldConcatPrompts: z.boolean().default(true),
refinerModel: zParameterSDXLRefinerModel.nullable().default(null),
refinerSteps: z.number().default(20),
refinerCFGScale: z.number().default(7.5),
refinerScheduler: zParameterScheduler.default('euler'),
refinerPositiveAestheticScore: z.number().default(6),
refinerNegativeAestheticScore: z.number().default(2.5),
refinerStart: z.number().default(0.8),
t5EncoderModel: zParameterT5EncoderModel.nullable().default(null),
clipEmbedModel: zParameterCLIPEmbedModel.nullable().default(null),
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable().default(null),
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable().default(null),
controlLora: zParameterControlLoRAModel.nullable().default(null),
dimensions: zDimensionsState.default({
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
}),
export const zParamsState = z.object({
maskBlur: z.number(),
maskBlurMethod: zParameterMaskBlurMethod,
canvasCoherenceMode: zParameterCanvasCoherenceMode,
canvasCoherenceMinDenoise: zParameterStrength,
canvasCoherenceEdgeSize: z.number(),
infillMethod: z.string(),
infillTileSize: z.number(),
infillPatchmatchDownscaleSize: z.number(),
infillColorValue: zRgbaColor,
cfgScale: zParameterCFGScale,
cfgRescaleMultiplier: zParameterCFGRescaleMultiplier,
guidance: zParameterGuidance,
img2imgStrength: zParameterStrength,
optimizedDenoisingEnabled: z.boolean(),
iterations: z.number(),
scheduler: zParameterScheduler,
upscaleScheduler: zParameterScheduler,
upscaleCfgScale: zParameterCFGScale,
seed: zParameterSeed,
shouldRandomizeSeed: z.boolean(),
steps: zParameterSteps,
model: zParameterModel.nullable(),
vae: zParameterVAEModel.nullable(),
vaePrecision: zParameterPrecision,
fluxVAE: zParameterVAEModel.nullable(),
seamlessXAxis: z.boolean(),
seamlessYAxis: z.boolean(),
clipSkip: z.number(),
shouldUseCpuNoise: z.boolean(),
positivePrompt: zParameterPositivePrompt,
negativePrompt: zParameterNegativePrompt,
positivePrompt2: zParameterPositiveStylePromptSDXL,
negativePrompt2: zParameterNegativeStylePromptSDXL,
shouldConcatPrompts: z.boolean(),
refinerModel: zParameterSDXLRefinerModel.nullable(),
refinerSteps: z.number(),
refinerCFGScale: z.number(),
refinerScheduler: zParameterScheduler,
refinerPositiveAestheticScore: z.number(),
refinerNegativeAestheticScore: z.number(),
refinerStart: z.number(),
t5EncoderModel: zParameterT5EncoderModel.nullable(),
clipEmbedModel: zParameterCLIPEmbedModel.nullable(),
clipLEmbedModel: zParameterCLIPLEmbedModel.nullable(),
clipGEmbedModel: zParameterCLIPGEmbedModel.nullable(),
controlLora: zParameterControlLoRAModel.nullable(),
dimensions: zDimensionsState,
});
export type ParamsState = z.infer<typeof zParamsState>;
const INITIAL_PARAMS_STATE = zParamsState.parse({});
export const getInitialParamsState = () => deepClone(INITIAL_PARAMS_STATE);
export const getInitialParamsState = (): ParamsState => ({
maskBlur: 16,
maskBlurMethod: 'box',
canvasCoherenceMode: 'Gaussian Blur',
canvasCoherenceMinDenoise: 0,
canvasCoherenceEdgeSize: 16,
infillMethod: 'lama',
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
cfgScale: 7.5,
cfgRescaleMultiplier: 0,
guidance: 4,
img2imgStrength: 0.75,
optimizedDenoisingEnabled: true,
iterations: 1,
scheduler: 'dpmpp_3m_k',
upscaleScheduler: 'kdpm_2',
upscaleCfgScale: 2,
seed: 0,
shouldRandomizeSeed: true,
steps: 30,
model: null,
vae: null,
vaePrecision: 'fp32',
fluxVAE: null,
seamlessXAxis: false,
seamlessYAxis: false,
clipSkip: 0,
shouldUseCpuNoise: true,
positivePrompt: '',
negativePrompt: null,
positivePrompt2: '',
negativePrompt2: '',
shouldConcatPrompts: true,
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler',
refinerPositiveAestheticScore: 6,
refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8,
t5EncoderModel: null,
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
controlLora: null,
dimensions: {
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
},
});
const zInpaintMasks = z.object({
isHidden: z.boolean(),
@@ -595,38 +621,45 @@ const zRegionalGuidance = z.object({
isHidden: z.boolean(),
entities: z.array(zCanvasRegionalGuidanceState),
});
const zCanvasState = z.object({
_version: z.literal(3).default(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable().default(null),
inpaintMasks: zInpaintMasks.default({ isHidden: false, entities: [] }),
rasterLayers: zRasterLayers.default({ isHidden: false, entities: [] }),
controlLayers: zControlLayers.default({ isHidden: false, entities: [] }),
regionalGuidance: zRegionalGuidance.default({ isHidden: false, entities: [] }),
bbox: zBboxState.default({
export const zCanvasState = z.object({
_version: z.literal(3),
selectedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
bookmarkedEntityIdentifier: zCanvasEntityIdentifer.nullable(),
inpaintMasks: zInpaintMasks,
rasterLayers: zRasterLayers,
controlLayers: zControlLayers,
regionalGuidance: zRegionalGuidance,
bbox: zBboxState,
});
export type CanvasState = z.infer<typeof zCanvasState>;
export const getInitialCanvasState = (): CanvasState => ({
_version: 3,
selectedEntityIdentifier: null,
bookmarkedEntityIdentifier: null,
inpaintMasks: { isHidden: false, entities: [] },
rasterLayers: { isHidden: false, entities: [] },
controlLayers: { isHidden: false, entities: [] },
regionalGuidance: { isHidden: false, entities: [] },
bbox: {
rect: { x: 0, y: 0, width: 512, height: 512 },
aspectRatio: DEFAULT_ASPECT_RATIO_CONFIG,
aspectRatio: deepClone(DEFAULT_ASPECT_RATIO_CONFIG),
scaleMethod: 'auto',
scaledSize: { width: 512, height: 512 },
modelBase: 'sd-1',
}),
},
});
export type CanvasState = z.infer<typeof zCanvasState>;
const zRefImagesState = z.object({
selectedEntityId: z.string().nullable().default(null),
isPanelOpen: z.boolean().default(false),
entities: z.array(zRefImageState).default(() => []),
export const zRefImagesState = z.object({
selectedEntityId: z.string().nullable(),
isPanelOpen: z.boolean(),
entities: z.array(zRefImageState),
});
export type RefImagesState = z.infer<typeof zRefImagesState>;
const INITIAL_REF_IMAGES_STATE = zRefImagesState.parse({});
export const getInitialRefImagesState = () => deepClone(INITIAL_REF_IMAGES_STATE);
/**
* Gets a fresh canvas initial state with no references in memory to existing objects.
*/
const CANVAS_INITIAL_STATE = zCanvasState.parse({});
export const getInitialCanvasState = () => deepClone(CANVAS_INITIAL_STATE);
export const getInitialRefImagesState = (): RefImagesState => ({
selectedEntityId: null,
isPanelOpen: false,
entities: [],
});
export const zCanvasReferenceImageState_OLD = zCanvasEntityBase.extend({
type: z.literal('reference_image'),

View File

@@ -27,6 +27,7 @@ export const DndImageIcon = memo((props: Props) => {
return (
<IconButton
onClick={onClick}
tooltip={tooltip}
aria-label={tooltip}
icon={icon}
variant="link"

View File

@@ -1,25 +1,29 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RootState } from 'app/store/store';
import type { SliceConfig } from 'app/store/types';
import { buildZodTypeGuard } from 'common/util/zodUtils';
import { isPlainObject } from 'es-toolkit';
import { assert } from 'tsafe';
import { z } from 'zod';
const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
export const isSeedBehaviour = buildZodTypeGuard(zSeedBehaviour);
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export interface DynamicPromptsState {
_version: 1;
maxPrompts: number;
combinatorial: boolean;
prompts: string[];
parsingError: string | undefined | null;
isError: boolean;
isLoading: boolean;
seedBehaviour: SeedBehaviour;
}
const zDynamicPromptsState = z.object({
_version: z.literal(1),
maxPrompts: z.number().int().min(1).max(1000),
combinatorial: z.boolean(),
prompts: z.array(z.string()),
parsingError: z.string().nullish(),
isError: z.boolean(),
isLoading: z.boolean(),
seedBehaviour: zSeedBehaviour,
});
export type DynamicPromptsState = z.infer<typeof zDynamicPromptsState>;
const initialDynamicPromptsState: DynamicPromptsState = {
const getInitialState = (): DynamicPromptsState => ({
_version: 1,
maxPrompts: 100,
combinatorial: true,
@@ -28,11 +32,11 @@ const initialDynamicPromptsState: DynamicPromptsState = {
isError: false,
isLoading: false,
seedBehaviour: 'PER_ITERATION',
};
});
export const dynamicPromptsSlice = createSlice({
const slice = createSlice({
name: 'dynamicPrompts',
initialState: initialDynamicPromptsState,
initialState: getInitialState(),
reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload;
@@ -63,21 +67,22 @@ export const {
isErrorChanged,
isLoadingChanged,
seedBehaviourChanged,
} = dynamicPromptsSlice.actions;
} = slice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrateDynamicPromptsState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};
export const dynamicPromptsPersistConfig: PersistConfig<DynamicPromptsState> = {
name: dynamicPromptsSlice.name,
initialState: initialDynamicPromptsState,
migrate: migrateDynamicPromptsState,
persistDenylist: ['prompts'],
export const dynamicPromptsSliceConfig: SliceConfig<typeof slice> = {
slice,
schema: zDynamicPromptsState,
getInitialState,
persistConfig: {
migrate: (state) => {
assert(isPlainObject(state));
if (!('_version' in state)) {
state._version = 1;
}
return zDynamicPromptsState.parse(state);
},
persistDenylist: ['prompts', 'parsingError', 'isError', 'isLoading'],
},
};
export const selectDynamicPromptsSlice = (state: RootState) => state.dynamicPrompts;

View File

@@ -53,6 +53,7 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => {
color={isSelected ? 'base.100' : 'base.300'}
onDoubleClick={editable.startEditing}
cursor="text"
noOfLines={1}
>
{editable.value}
</Text>

View File

@@ -37,6 +37,7 @@ export const BoardTooltip = ({ board }: Props) => {
/>
)}
<Flex flexDir="column" alignItems="center">
{board && <Text fontWeight="semibold">{board.board_name}</Text>}
<Text noOfLines={1}>
{t('boards.imagesWithCount', { count: imagesTotal })}, {t('boards.assetsWithCount', { count: assetsTotal })}
</Text>

View File

@@ -59,7 +59,7 @@ export const BoardsPanel = memo(() => {
onClick={collapsibleApi.toggle}
leftIcon={isCollapsed ? <PiCaretDownBold /> : <PiCaretUpBold />}
>
Boards
{t('boards.boards')}
</Button>
</Flex>
<Flex>

View File

@@ -75,6 +75,7 @@ export const GalleryPanel = memo(() => {
variant="ghost"
onClick={collapsibleApi.toggle}
leftIcon={isCollapsed ? <PiCaretDownBold /> : <PiCaretUpBold />}
noOfLines={1}
>
{boardName}
</Button>

View File

@@ -2,6 +2,7 @@ import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useCanvasIsBusySafe } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { newCanvasFromImage } from 'features/imageActions/actions';
import { toast } from 'features/toast/toast';
@@ -17,6 +18,7 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
const store = useAppStore();
const imageDTO = useImageDTOContext();
const isBusy = useCanvasIsBusySafe();
const isStaging = useCanvasIsStaging();
const onClickNewCanvasWithRasterLayerFromImage = useCallback(async () => {
const { dispatch, getState } = store;
@@ -97,27 +99,31 @@ export const ImageMenuItemNewCanvasFromImageSubMenu = memo(() => {
<SubMenuButtonContent label={t('controlLayers.newCanvasFromImage')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<MenuItem icon={<PiFileBold />} onClickCapture={onClickNewCanvasWithRasterLayerFromImage} isDisabled={isBusy}>
<MenuItem
icon={<PiFileBold />}
onClickCapture={onClickNewCanvasWithRasterLayerFromImage}
isDisabled={isStaging || isBusy}
>
{t('controlLayers.asRasterLayer')}
</MenuItem>
<MenuItem
icon={<PiFileBold />}
onClickCapture={onClickNewCanvasWithRasterLayerFromImageWithResize}
isDisabled={isBusy}
isDisabled={isStaging || isBusy}
>
{t('controlLayers.asRasterLayerResize')}
</MenuItem>
<MenuItem
icon={<PiFileBold />}
onClickCapture={onClickNewCanvasWithControlLayerFromImage}
isDisabled={isBusy}
isDisabled={isStaging || isBusy}
>
{t('controlLayers.asControlLayer')}
</MenuItem>
<MenuItem
icon={<PiFileBold />}
onClickCapture={onClickNewCanvasWithControlLayerFromImageWithResize}
isDisabled={isBusy}
isDisabled={isStaging || isBusy}
>
{t('controlLayers.asControlLayerResize')}
</MenuItem>

View File

@@ -1,5 +1,6 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
@@ -14,7 +15,7 @@ export const ImageMenuItemSendToUpscale = memo(() => {
const imageDTO = useImageDTOContext();
const handleSendToCanvas = useCallback(() => {
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(upscaleInitialImageChanged(imageDTOToImageWithDims(imageDTO)));
navigationApi.switchToTab('upscaling');
toast({
id: 'SENT_TO_CANVAS',

View File

@@ -28,7 +28,7 @@ export const ImageMenuItemUseAsRefImage = memo(() => {
return (
<MenuItem icon={<PiImageBold />} onClickCapture={onClickNewGlobalReferenceImageFromImage}>
Use as Reference Image
{t('controlLayers.useAsReferenceImage')}
</MenuItem>
);
});

View File

@@ -60,7 +60,7 @@ export const CompareToolbar = memo(() => {
useRegisteredHotkeys({ id: 'nextComparisonMode', category: 'viewer', callback: nextMode, dependencies: [nextMode] });
return (
<Flex w="full" px={2} gap={2} bg="base.750" borderTopRadius="base" h={12}>
<Flex w="full" justifyContent="center" h={8}>
<Flex flex={1} justifyContent="center">
<Flex marginInlineEnd="auto" alignItems="center">
<IconButton
@@ -85,7 +85,7 @@ export const CompareToolbar = memo(() => {
</Flex>
</Flex>
<Flex flex={1} justifyContent="center">
<ButtonGroup variant="outline" alignItems="center">
<ButtonGroup size="sm" variant="outline" alignItems="center">
<Button
flexShrink={0}
onClick={setComparisonModeSlider}
@@ -117,6 +117,7 @@ export const CompareToolbar = memo(() => {
</Flex>
</Tooltip>
<Button
size="sm"
variant="link"
alignSelf="stretch"
px={2}

View File

@@ -1,15 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { ProgressImage } from 'features/gallery/components/ImageViewer/ProgressImage';
import { memo } from 'react';
import { ProgressIndicator } from './ProgressIndicator';
export const GenerationProgressPanel = memo(() => {
return (
<Flex position="relative" flexDir="column" w="full" h="full" overflow="hidden">
<ProgressImage />
<ProgressIndicator position="absolute" top={6} right={6} size={8} />
</Flex>
);
});
GenerationProgressPanel.displayName = 'GenerationProgressPanel';

View File

@@ -1,32 +1,36 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { Box, Divider, Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { ComparisonProps } from 'features/gallery/components/ImageViewer/common';
import { debounce } from 'es-toolkit';
import type { ComparisonWrapperProps } from 'features/gallery/components/ImageViewer/common';
import { selectImageToCompare } from 'features/gallery/components/ImageViewer/common';
import { CompareToolbar } from 'features/gallery/components/ImageViewer/CompareToolbar';
import { ImageComparisonDroppable } from 'features/gallery/components/ImageViewer/ImageComparisonDroppable';
import { ImageComparisonHover } from 'features/gallery/components/ImageViewer/ImageComparisonHover';
import { ImageComparisonSideBySide } from 'features/gallery/components/ImageViewer/ImageComparisonSideBySide';
import { ImageComparisonSlider } from 'features/gallery/components/ImageViewer/ImageComparisonSlider';
import { selectComparisonMode } from 'features/gallery/store/gallerySelectors';
import { memo } from 'react';
import { useMeasure } from 'react-use';
import { selectComparisonMode, selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback, useLayoutEffect, useRef, useState } from 'react';
import { useImageDTO } from 'services/api/endpoints/images';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const ImageComparisonContent = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
const ImageComparisonContent = memo(({ firstImage, secondImage, rect }: ComparisonWrapperProps) => {
const comparisonMode = useAppSelector(selectComparisonMode);
if (!firstImage || !secondImage) {
return null;
}
if (comparisonMode === 'slider') {
return <ImageComparisonSlider firstImage={firstImage} secondImage={secondImage} containerDims={containerDims} />;
return <ImageComparisonSlider firstImage={firstImage} secondImage={secondImage} rect={rect} />;
}
if (comparisonMode === 'side-by-side') {
return (
<ImageComparisonSideBySide firstImage={firstImage} secondImage={secondImage} containerDims={containerDims} />
);
return <ImageComparisonSideBySide firstImage={firstImage} secondImage={secondImage} rect={rect} />;
}
if (comparisonMode === 'hover') {
return <ImageComparisonHover firstImage={firstImage} secondImage={secondImage} containerDims={containerDims} />;
return <ImageComparisonHover firstImage={firstImage} secondImage={secondImage} rect={rect} />;
}
assert<Equals<never, typeof comparisonMode>>(false);
@@ -34,16 +38,51 @@ const ImageComparisonContent = memo(({ firstImage, secondImage, containerDims }:
ImageComparisonContent.displayName = 'ImageComparisonContent';
export const ImageComparison = memo(({ firstImage, secondImage }: Omit<ComparisonProps, 'containerDims'>) => {
const [containerRef, containerDims] = useMeasure<HTMLDivElement>();
export const ImageComparison = memo(() => {
const lastSelectedImageName = useAppSelector(selectLastSelectedImage);
const lastSelectedImageDTO = useImageDTO(lastSelectedImageName);
const comparisonImageName = useAppSelector(selectImageToCompare);
const comparisonImageDTO = useImageDTO(comparisonImageName);
const [rect, setRect] = useState<DOMRect | null>(null);
const ref = useRef<HTMLDivElement | null>(null);
// Ref callback runs synchronously when the DOM node is attached, ensuring we have a measurement before
// the comparison content is rendered.
const measureNode = useCallback((node: HTMLDivElement) => {
if (node) {
ref.current = node;
const boundingRect = node.getBoundingClientRect();
setRect(boundingRect);
}
}, []);
useLayoutEffect(() => {
const el = ref.current;
if (!el) {
return;
}
const measureRect = debounce(() => {
const boundingRect = el.getBoundingClientRect();
setRect(boundingRect);
}, 300);
const observer = new ResizeObserver(measureRect);
observer.observe(el);
return () => {
observer.disconnect();
};
}, []);
return (
<Flex flexDir="column" w="full" h="full" position="relative">
<Flex flexDir="column" w="full" h="full" overflow="hidden" gap={2} position="relative">
<CompareToolbar />
<Box ref={containerRef} w="full" h="full" p={2} overflow="hidden">
<ImageComparisonContent firstImage={firstImage} secondImage={secondImage} containerDims={containerDims} />
</Box>
<ImageComparisonDroppable />
<Divider />
<Flex w="full" h="full" position="relative">
<Box ref={measureNode} w="full" h="full" overflow="hidden">
<ImageComparisonContent firstImage={lastSelectedImageDTO} secondImage={comparisonImageDTO} rect={rect} />
</Box>
<ImageComparisonDroppable />
</Flex>
</Flex>
);
});

View File

@@ -11,14 +11,16 @@ import { memo, useMemo, useRef } from 'react';
import type { ComparisonProps } from './common';
import { fitDimsToContainer, getSecondImageDims } from './common';
export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
export const ImageComparisonHover = memo(({ firstImage, secondImage, rect }: ComparisonProps) => {
const comparisonFit = useAppSelector(selectComparisonFit);
const imageContainerRef = useRef<HTMLDivElement>(null);
const mouseOver = useBoolean(false);
const fittedDims = useMemo<Dimensions>(
() => fitDimsToContainer(containerDims, firstImage),
[containerDims, firstImage]
);
const fittedDims = useMemo<Dimensions>(() => {
if (!rect) {
return { width: 0, height: 0 };
}
return fitDimsToContainer(rect, firstImage);
}, [firstImage, rect]);
const compareImageDims = useMemo<Dimensions>(
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
[comparisonFit, fittedDims, firstImage, secondImage]

View File

@@ -19,7 +19,7 @@ const HANDLE_HITBOX_PX = `${HANDLE_HITBOX}px`;
const HANDLE_INNER_LEFT_PX = `${HANDLE_HITBOX / 2 - HANDLE_WIDTH / 2}px`;
const HANDLE_LEFT_INITIAL_PX = `calc(${INITIAL_POS} - ${HANDLE_HITBOX / 2}px)`;
export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
export const ImageComparisonSlider = memo(({ firstImage, secondImage, rect }: ComparisonProps) => {
const comparisonFit = useAppSelector(selectComparisonFit);
// How far the handle is from the left - this will be a CSS calculation that takes into account the handle width
@@ -33,10 +33,12 @@ export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerD
const rafRef = useRef<number | null>(null);
const lastMoveTimeRef = useRef<number>(0);
const fittedDims = useMemo<Dimensions>(
() => fitDimsToContainer(containerDims, firstImage),
[containerDims, firstImage]
);
const fittedDims = useMemo<Dimensions>(() => {
if (!rect) {
return { width: 0, height: 0 };
}
return fitDimsToContainer(rect, firstImage);
}, [firstImage, rect]);
const compareImageDims = useMemo<Dimensions>(
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),

Some files were not shown because too many files have changed in this diff Show More