Compare commits

...

190 Commits

Author SHA1 Message Date
psychedelicious
005402b0a7 chore: ruff 2025-06-25 01:20:04 +10:00
psychedelicious
a6baeba357 chore: bump version to v6.0.0a6 2025-06-25 01:14:12 +10:00
psychedelicious
ccdd58838a fix(ui): minor jank when siwtching images rapidly 2025-06-25 01:10:49 +10:00
psychedelicious
e4170df91d feat(ui): scrollbar styles 2025-06-25 01:06:02 +10:00
psychedelicious
190a7eef58 refactor: gallery scroll (improved impl) 2025-06-25 00:51:16 +10:00
psychedelicious
18a93a5164 refactor: gallery scroll (improved impl) 2025-06-25 00:38:28 +10:00
psychedelicious
a1cf3af732 refactor: gallery scroll (improved impl) 2025-06-24 23:33:37 +10:00
psychedelicious
4f191fe4b3 refactor: gallery scroll (improved impl) 2025-06-24 23:00:24 +10:00
psychedelicious
097d0da09f refactor: gallery scroll (improved impl) 2025-06-24 21:35:41 +10:00
psychedelicious
184bcfaf06 refactor: gallery scroll (improved impl) 2025-06-24 21:31:37 +10:00
psychedelicious
994236e9a8 refactor: gallery scroll (improved impl) 2025-06-24 21:01:29 +10:00
psychedelicious
74004aea04 refactor: gallery scroll (improved impl) 2025-06-24 20:46:45 +10:00
psychedelicious
ab27832d0c refactor: gallery scroll (improved impl) 2025-06-24 19:58:59 +10:00
psychedelicious
ef3d260657 refactor: gallery scroll (improved impl) 2025-06-24 17:23:04 +10:00
psychedelicious
4c462c2423 refactor: gallery scroll (improved impl) 2025-06-24 16:03:29 +10:00
psychedelicious
f797061390 refactor: gallery scroll 2025-06-24 15:51:28 +10:00
psychedelicious
83cca78e7c fix(ui): fix metadata toggle stuck disabled 2025-06-24 06:07:17 +10:00
psychedelicious
89cd24d3e2 chore: bump version to v6.0.0a5 2025-06-24 05:41:21 +10:00
psychedelicious
ae3e9f0007 chore(ui): lint 2025-06-24 05:41:21 +10:00
psychedelicious
7a5d0a8973 refactor(ui): use image names for selection instead of dtos
Update the frontend to incorporate the previous changes to how image
selection and general image identification is handled in the frontend.
2025-06-24 05:41:21 +10:00
psychedelicious
515e270908 chore(ui): typegen 2025-06-24 05:41:20 +10:00
psychedelicious
69cd265124 feat(api): return more data when doing image/board mutations
When we delete images, boards, or do any other board mutation, we need
to invalidate numerous query caches and related internal frontend state.
This gets complicated very quickly.

We can drastically reduce the complexity by having the backend return
some more information when we make these mutations.

For example, when deleting a list of images by name, we can return a
list of deleted image name and affected boards. The frontend can use
this information to determine which queries to invalidate with far less
tedium.

This will also enable the more efficient storage of images (e.g. in the
gallery selection). Previously, we had to store the entire image DTO
object, else we wouldn't be able to figure out which queries to
invalidate. But now that the backend tells us exactly what images/boards
have changed, we can just store image names in frontend state. This
amounts to a substantial improvement in DX and reduction in frontend
complexity.
2025-06-24 05:41:20 +10:00
psychedelicious
88164ed268 feat(ui): viewer integrates progress (wip) 2025-06-24 05:41:20 +10:00
psychedelicious
8566ede81a feat(ui): switch to viewer/canvas on invoke 2025-06-24 05:41:20 +10:00
psychedelicious
1690d10197 feat(ui): generation progress tab improvements 2025-06-24 05:41:20 +10:00
psychedelicious
5ec022323a feat(ui): show last progress message & placeholder in generation progress panel 2025-06-24 05:41:20 +10:00
psychedelicious
90815551d6 fix(ui): staging area does not show placeholder on first render 2025-06-24 05:41:20 +10:00
psychedelicious
dac23c54c9 feat(ui): double-click staging area image to disable auto-switch 2025-06-24 05:41:20 +10:00
psychedelicious
da116eb09b fix(ui): reset last started item id when doing autoswitch 2025-06-24 05:41:20 +10:00
psychedelicious
5f8e21e809 feat(ui): re-implement multiple auto-switch modes 2025-06-24 05:41:20 +10:00
psychedelicious
8d53cbbcdd chore: bump version to v6.0.0a4 2025-06-24 05:41:20 +10:00
psychedelicious
baac5f06d6 feat(ui): no model error state for ref images 2025-06-24 05:41:20 +10:00
psychedelicious
cf781a3b99 feat(ui): mini metadata viewer 2025-06-24 05:41:20 +10:00
psychedelicious
1e47e6fe0a feat(ui): clean up image view components & code 2025-06-24 05:41:19 +10:00
psychedelicious
cb15841eaf fix(ui): launchpad layouts 2025-06-24 05:41:19 +10:00
psychedelicious
527d89d07b fix(ui): don't use layers when generating on generate tab 2025-06-24 05:41:19 +10:00
psychedelicious
d36f02a20f feat(ui): tweak vertical tab bar layout 2025-06-24 05:41:19 +10:00
psychedelicious
432c65795a fix(ui): unable to resize prompt box bc negative prompt button is over
the handle
2025-06-24 05:41:19 +10:00
psychedelicious
c734924ea5 feat(ui): standardize auto layout structure 2025-06-24 05:41:19 +10:00
psychedelicious
b61e6d5760 feat(ui): tweak dockview tabs 2025-06-24 05:41:19 +10:00
psychedelicious
06289da0c9 refactor(ui): rip out image viewer as modal 2025-06-24 05:41:19 +10:00
psychedelicious
d97eb84c4e chore: bump version to v6.0.0a3 2025-06-24 05:41:19 +10:00
psychedelicious
15f212b9f0 chore(ui): lint 2025-06-24 05:41:19 +10:00
psychedelicious
5e573119ab feat(ui): restore all panel hotkeys 2025-06-24 05:41:19 +10:00
psychedelicious
20effc5da6 fix(ui): generate tab hotkey 2025-06-24 05:41:19 +10:00
psychedelicious
c2266da827 feat(ui): restore floating panel buttons 2025-06-24 05:41:18 +10:00
psychedelicious
9d8e182227 feat(ui): get all tabs working w/ new layout 2025-06-24 05:41:18 +10:00
psychedelicious
49420d3449 fix(ui): unnecessary dependency on tab selection in
useCanvasDeleteLayerHotkey
2025-06-24 05:41:18 +10:00
psychedelicious
7e04454106 fix(ui): inverted logic for resume queue button 2025-06-24 05:41:18 +10:00
psychedelicious
78fa526312 feat(ui): get layouts working 2025-06-24 05:41:18 +10:00
psychedelicious
8d0541b06e feat(ui): canvas launchpad 2025-06-24 05:41:18 +10:00
psychedelicious
f5b093b980 wip 2025-06-24 05:41:18 +10:00
psychedelicious
906b2f852c fix(ui): wonky stage sizing on first visibility 2025-06-24 05:41:18 +10:00
psychedelicious
83d41c8bd1 wip 2025-06-24 05:41:18 +10:00
psychedelicious
afb318fd76 feat(ui): port UI slice to zod 2025-06-24 05:41:18 +10:00
psychedelicious
00410d1376 fix(ui): only show weight for IP adapters 2025-06-24 05:41:18 +10:00
psychedelicious
1662063152 feat(ui): represent IP adapter weight in ref image thumbnail 2025-06-24 05:41:18 +10:00
psychedelicious
376e5836b8 fix(ui): overflow on ref image model 2025-06-24 05:41:17 +10:00
psychedelicious
1a79109998 feat(ui): ref images feel more like buttons 2025-06-24 05:41:17 +10:00
psychedelicious
d0ff7256c9 feat(ui): switch tab on drag over tab button 2025-06-24 05:41:17 +10:00
psychedelicious
87ddaa602d feat(ui): tweak splash screen layout 2025-06-24 05:41:17 +10:00
psychedelicious
32f268af20 chore(ui): lint 2025-06-24 05:41:17 +10:00
psychedelicious
006c90127b feat(ui): rework simple session initial state 2025-06-24 05:41:17 +10:00
psychedelicious
ff21055cfb fix(ui): invoke button tooltip on generate tab 2025-06-24 05:41:17 +10:00
psychedelicious
92d16ffa96 fix(ui): progress image fixes 2025-06-24 05:41:17 +10:00
psychedelicious
e8a64ac766 feat(ui): make autoswitch on/off
When the invocation cache is used, we might skip all progress images. This can prevent auto-switch-on-first-progress from working, as we don't get any of those events.

It's much easier to only support auto-switch on complete.
2025-06-24 05:41:17 +10:00
psychedelicious
fd42db0b1b feat(ui): refine ref images UI 2025-06-24 05:41:17 +10:00
psychedelicious
5ad3f611b6 feat(ui): toggleable negative prompt 2025-06-24 05:41:17 +10:00
psychedelicious
6046ac2f75 fix(ui): remove old isSelected from refImageAdded call 2025-06-24 05:41:17 +10:00
psychedelicious
933a45e0fe chore: bump version to v6.0.0a2 2025-06-24 05:41:16 +10:00
psychedelicious
177f879bb3 fix(ui): update queue item preview images on init of queue items context 2025-06-24 05:41:16 +10:00
psychedelicious
fe09ff3501 fix(ui): hack to close chakra tooltips on drag 2025-06-24 05:41:16 +10:00
psychedelicious
5313aac7f8 tweak(ui): ref image header 2025-06-24 05:41:16 +10:00
psychedelicious
8ece17f39f experiment(ui): add generate tab 2025-06-24 05:41:16 +10:00
psychedelicious
d18c4cb6e5 refactor(ui): ref images (WIP) 2025-06-24 05:41:16 +10:00
psychedelicious
1a8bf3cac8 refactor(ui): ref images (WIP) 2025-06-24 05:41:16 +10:00
psychedelicious
194e42ac99 refactor(ui): refImage.ipAdapter -> refImage.config 2025-06-24 05:41:16 +10:00
psychedelicious
25e002e4d6 feat(ui): split out ref images into own slice (WIP) 2025-06-24 05:41:16 +10:00
psychedelicious
2206bc543e feat(ui): simple session initial state cards are buttons 2025-06-24 05:41:16 +10:00
psychedelicious
7547c17758 chore(ui): dpdm 2025-06-24 05:41:16 +10:00
psychedelicious
3d58da6d18 refactor(ui): async modal pattern; use for deleting images
This was needed for a canvas flow change which is currently paused, but the new API is much much nicer to use, so I am keeping it.
2025-06-24 05:41:16 +10:00
psychedelicious
2b06f252be fix(ui): use imageDTO in staging area 2025-06-24 05:41:15 +10:00
psychedelicious
265f74e642 fix(ui): wait until last queue item deleted before flagging canvas session finished 2025-06-24 05:41:15 +10:00
psychedelicious
c9ee9d9c27 feat(ui): store output image DTO in session context instead of just the name 2025-06-24 05:41:15 +10:00
psychedelicious
72ab5268c7 feat(ui): add AppGetState type 2025-06-24 05:41:15 +10:00
psychedelicious
237ae28373 feat(ui): close viewer on escape 2025-06-24 05:41:15 +10:00
psychedelicious
f6e7a1bde7 fix(ui): switch only on first progress image 2025-06-24 05:41:15 +10:00
psychedelicious
4fb40fcf86 feat(ui): add on first progress autoswitch mode 2025-06-24 05:41:15 +10:00
psychedelicious
2a10c00117 feat(ui): move canvas-specific staging subscriptions to CanvasStagingAreaModule 2025-06-24 05:41:15 +10:00
psychedelicious
46e6334e27 chore(ui): lint 2025-06-24 05:41:15 +10:00
psychedelicious
e0fbb9b916 feat(ui): make main panel styling and title consistent 2025-06-24 05:41:15 +10:00
psychedelicious
607b292a16 feat(ui): add startover button to canvas toolbar 2025-06-24 05:41:15 +10:00
psychedelicious
a05b8cf536 feat(ui): fiddle w/ staging area header 2025-06-24 05:41:15 +10:00
psychedelicious
2145dd217e feat(ui): remove technical progress message from full preview 2025-06-24 05:41:15 +10:00
psychedelicious
7d58b37d68 feat(ui): simple session initial state 2025-06-24 05:41:14 +10:00
psychedelicious
c85f3f88db feat(ui): remove vary and edit as control buttons 2025-06-24 05:41:14 +10:00
psychedelicious
b7751a85c2 refactor(ui): migrate from canceling queue items to deleteing, make queue hook APIs consistent 2025-06-24 05:41:14 +10:00
psychedelicious
921dcd81b8 fix(ui): mini preview bg color 2025-06-24 05:41:14 +10:00
psychedelicious
01acac8c4e fix(ui): hide layers when not on canvas tab 2025-06-24 05:41:14 +10:00
psychedelicious
3f7a9b7d82 build(ui): temporarily ignore all knip issues 2025-06-24 05:41:14 +10:00
psychedelicious
21f8263ddf feat(ui): finish generation when discarding last item 2025-06-24 05:41:14 +10:00
psychedelicious
c2db93669c feat(ui): when discarding last item, select new last instead of first 2025-06-24 05:41:14 +10:00
psychedelicious
37c0961597 feat(ui): tweak staging image display 2025-06-24 05:41:14 +10:00
psychedelicious
b13be4891a feat(ui): add staging area toolbar to simple session 2025-06-24 05:41:14 +10:00
psychedelicious
5cad4ed06d fix(ui): ensure canvas tool modules are destroyed 2025-06-24 05:41:14 +10:00
psychedelicious
c8c8a79c07 fix(ui): reset layers when changing session type 2025-06-24 05:41:14 +10:00
psychedelicious
22ff5f71cc feat(ui): improved staging placeholders 2025-06-24 05:41:13 +10:00
psychedelicious
b2ee934e8c feat(ui): improved staging placeholders 2025-06-24 05:41:13 +10:00
psychedelicious
19f8bb4795 feat(ui): more staging fixes 2025-06-24 05:41:13 +10:00
psychedelicious
85e17aa36b feat(ui): update canvas session state handling for new staging strat 2025-06-24 05:41:13 +10:00
psychedelicious
88680a75c9 chore(ui): lint (partial cleanup) 2025-06-24 05:41:13 +10:00
psychedelicious
b038c79451 feat(ui): rough out canvas staging area 2025-06-24 05:41:13 +10:00
psychedelicious
088eea9a0e feat(app): support deleting queue items by id or destination 2025-06-24 05:41:13 +10:00
psychedelicious
c575eb14ca feat(ui): tweak canvas scroll to zoom feel 2025-06-24 05:41:13 +10:00
psychedelicious
ffb3dc2bcd docs(ui): add comment about auto-switch not being quite right yet 2025-06-24 05:41:13 +10:00
psychedelicious
b1b1d7c2a6 feat: canvas flow rework (wip) 2025-06-24 05:41:13 +10:00
psychedelicious
89b34cb225 feat(ui): prevent flicker of image action buttons 2025-06-24 05:41:13 +10:00
psychedelicious
aa4b0d6705 feat(ui): move socket events handling into ctx component 2025-06-24 05:41:13 +10:00
psychedelicious
b97ad8518c feat(ui): modularize all staging area logic so it can be shared w/ canvas more easily 2025-06-24 05:41:13 +10:00
psychedelicious
eda3cc2306 perf(ui): queue actions menu is lazy 2025-06-24 05:41:12 +10:00
psychedelicious
ec564725b1 fix(ui): cursor on staging area preview image 2025-06-24 05:41:12 +10:00
psychedelicious
c2fc3c6328 feat(ui): remove clear queue ui components 2025-06-24 05:41:12 +10:00
psychedelicious
fdd2051257 feat(app): do not prune queue on startup
With the new canvas design, this will result in loss of staging area images.
2025-06-24 05:41:12 +10:00
psychedelicious
4d7213baf7 tidy(ui): component organization 2025-06-24 05:41:12 +10:00
psychedelicious
533018b14e fix(ui): prevent drag of progress images 2025-06-24 05:41:12 +10:00
psychedelicious
d2b486cd8e feat: canvas flow rework (wip) 2025-06-24 05:41:12 +10:00
psychedelicious
e05b3cdd8a feat: canvas flow rework (wip) 2025-06-24 05:41:12 +10:00
psychedelicious
6c98b4b38d chore(ui): typegen 2025-06-24 05:41:12 +10:00
psychedelicious
a10f5efd16 feat(api): remove status from list all queue items query 2025-06-24 05:41:12 +10:00
psychedelicious
6354d363c5 tidy(ui): app layout components 2025-06-24 05:41:12 +10:00
psychedelicious
161559a1fa feat: canvas flow rework (wip) 2025-06-24 05:41:11 +10:00
psychedelicious
c0c1649436 feat: canvas flow rework (wip) 2025-06-24 05:41:11 +10:00
psychedelicious
10105a7e4e feat: canvas flow rework (wip) 2025-06-24 05:41:11 +10:00
psychedelicious
ffb6e87d50 fix(ui): unstable selector results in lora drop down 2025-06-24 05:41:11 +10:00
psychedelicious
d667bb1741 feat: canvas flow rework (wip) 2025-06-24 05:41:11 +10:00
psychedelicious
52f8cf6840 feat: canvas flow rework (wip) 2025-06-24 05:41:11 +10:00
psychedelicious
1a3b3dad8a wip progress events 2025-06-24 05:41:11 +10:00
psychedelicious
4957dd8fa2 refactor(ui): canvas flow (wip) 2025-06-24 05:41:11 +10:00
psychedelicious
30e22a0728 fix(ui): ref goes undefined in GalleryImage
This appears to be a bug in Chakra UI v2 - use of a fallback component makes the ref passed to an image end up undefined. Had to remove the skeleton loader fallback component.
2025-06-24 05:41:11 +10:00
psychedelicious
2b297d4d37 fix(ui): merge refs when forwardingin DndImage 2025-06-24 05:41:11 +10:00
psychedelicious
5cbac72167 fix(ui): remove unused sessionId field from type 2025-06-24 05:41:10 +10:00
psychedelicious
3b3672e4ae fix(ui): ensure all args are passed to handler when creating new canvas from image 2025-06-24 05:41:10 +10:00
psychedelicious
e39ce1fee2 feat(ui): bookmark new inpaint masks 2025-06-24 05:41:10 +10:00
psychedelicious
6e0e394095 feat(ui): support bookmarking an entity when adding it 2025-06-24 05:41:10 +10:00
psychedelicious
3ec39c9b3f fix(ui): ensure images are added to gallery in simple sessions 2025-06-24 05:41:10 +10:00
psychedelicious
b19df9ddcf feat(ui): images always added to gallery in simple session 2025-06-24 05:41:10 +10:00
psychedelicious
0c34a49c58 wip 2025-06-24 05:41:10 +10:00
psychedelicious
a0721835d6 refactor(ui): canvas flow (wip) 2025-06-24 05:41:10 +10:00
psychedelicious
6a534d5b4f refactor(ui): canvas flow (wip) 2025-06-24 05:41:10 +10:00
psychedelicious
0f680e16b6 refactor(ui): canvas flow events (wip) 2025-06-24 05:41:10 +10:00
psychedelicious
83e82e25a6 refactor(ui): canvas flow (wip) 2025-06-24 05:41:09 +10:00
psychedelicious
b85736ccf3 refactor(ui): canvas flow (wip) 2025-06-24 05:41:09 +10:00
psychedelicious
2fb1d93038 refactor(ui): canvas flow (wip) 2025-06-24 05:41:09 +10:00
psychedelicious
05543b6073 refactor(ui): canvas flow (wip) 2025-06-24 05:41:09 +10:00
psychedelicious
95986e4aa0 fix(ui): circular import issue 2025-06-24 05:41:09 +10:00
psychedelicious
f7bf459721 refactor(ui): params state zodification 2025-06-24 05:41:09 +10:00
psychedelicious
b266ce78f5 refactor(ui): move params state to big file of canvas zod stuff 2025-06-24 05:41:09 +10:00
psychedelicious
fcaf216a3d refactor(ui): zod-ify params slice state 2025-06-24 05:41:09 +10:00
psychedelicious
12fdc934ee refactor(ui): org state in prep for new flow 2025-06-24 05:41:09 +10:00
psychedelicious
dee7b74c59 refactor(ui): image viewer & comparison convolutedness 2025-06-24 05:41:09 +10:00
psychedelicious
1411395b06 feat(ui): default canvas tool is move 2025-06-24 05:41:08 +10:00
psychedelicious
e3a56c8b81 chore(ui): bump @reduxjs/toolkit to latest 2025-06-24 05:41:08 +10:00
psychedelicious
eb41fd1ca6 feat(ui): viewer is a modal (wip) 2025-06-24 05:41:08 +10:00
Kent Keirsey
6a78739076 Change save button to Invoke Blue 2025-06-20 15:07:40 +10:00
psychedelicious
0794eb43e7 fix(nodes): ensure each invocation overrides _original_model_fields with own field data 2025-06-20 15:03:55 +10:00
Mary Hipp
291e0736d6 fix names of unpublishable nodes 2025-06-16 12:40:54 -04:00
psychedelicious
4bfa6439d4 chore(ui): typgen 2025-06-16 19:33:19 +10:00
psychedelicious
a8d7969a1d fix(app): config docstrings 2025-06-16 19:33:19 +10:00
Heathen711
46bfa24af3 ruff format 2025-06-16 19:33:19 +10:00
Heathen711
a8cb8e128d run "make frontend-typegen" 2025-06-16 19:33:19 +10:00
Heathen711
8cef0f5bf5 Update supported cuda slot input. 2025-06-16 19:33:19 +10:00
psychedelicious
911baeb58b chore(ui): bump version to v5.15.0 2025-06-16 19:18:25 +10:00
Kevin Turner
312960645b fix: move AI Toolkit to the bottom of the detection list
to avoid disrupting already-working LoRA
2025-06-16 19:08:11 +10:00
Kevin Turner
50cf285efb fix: group aitoolkit lora layers 2025-06-16 19:08:11 +10:00
Kevin Turner
a214f4fff5 fix: group aitoolkit lora layers 2025-06-16 19:08:11 +10:00
Kevin Turner
2981591c36 test: add some aitoolkit lora tests 2025-06-16 19:08:11 +10:00
Kevin Turner
b08f90c99f WIP!: …they weren't in diffusers format… 2025-06-16 19:08:11 +10:00
Kevin Turner
ab8c739cd8 fix(LoRA): add ai-toolkit to lora loader 2025-06-16 19:08:11 +10:00
Kevin Turner
5c5108c28a feat(LoRA): support AI Toolkit LoRA for FLUX [WIP] 2025-06-16 19:08:11 +10:00
j-brooke
3df7cfd605 Updated fracturedjsonjs to version 4.1.0 and included settings adjustments for more pleasing comma placement. 2025-06-14 14:59:43 +10:00
psychedelicious
1ff3d44dba fix(app): guard against possible race conditions during enqueue
In #7724 we made a number of perf optimisations related to enqueuing. One of these optimisations included moving the enqueue logic - including expensive prep work and db writes - to a separate thread.

At the same time manual DB locking was abandoned in favor of WAL mode.

Finally, we set `check_same_thread=False` to allow multiple threads to access the connection at a given time.

I think this may be the cause of #7950:
- We start an enqueue in a thread (running in bg)
- We dequeue
- Dequeue pulls a partially-written queue item from DB and we get the errors in the linked issue

To be honest, I don't understand enough about SQLite to confidently say that this kind of race condition is actually possible. But:
- The error started popping up around the time we made this change.
- I have reviewed the logic from enqueue to dequeue very carefully _many_ times over the past month or so, and I am confident that the error is only possible if we are getting unexpectedly `NULL` values from the DB.
- The DB schema includes `NOT NULL` constraints for the column that is apparently returning `NULL`.
- Therefore, without some kind of race condition or schema issue, the error should not be possible.
- The `enqueue_batch` call is the only place I can find where we have the possibility of a race condition due to async logic. Everywhere else, all DB interaction for the queue is synchronous, as far as I can tell.

This change retains the perf benefits by running the heavy enqueue prep logic in a separate thread, but moves back to the main thread for the DB write. It also uses an explicit transaction for the write.

Will just have to wait and see if this fixes the issue.
2025-06-13 23:51:47 +10:00
Emmanuel Ferdman
c80ad90f72 Migrate to modern logger interface
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-13 13:07:09 +10:00
psychedelicious
3b4d1b8786 perf(app): gc before every queue item
This reduces peak memory usage at a negligible cost. Queue items typically take on the order of seconds, making the time cost of a GC essentially free.

Not a great idea on a hotter code path though.
2025-06-11 12:56:16 +10:00
psychedelicious
c66201c7e1 perf(app): skip TI logic when no TIs to apply 2025-06-11 12:56:16 +10:00
psychedelicious
35c7c59455 fix(app): reduce peak memory usage
We've long suspected there is a memory leak in Invoke, but that may not be true. What looks like a memory leak may in fact be the expected behaviour for our allocation patterns.

We observe ~20 to ~30 MB increase in memory usage per session executed. I did some prolonged tests, where I measured the process's RSS in bytes while doing 200 SDXL generations. I found that it eventually leveled off at around 100 generations, at which point memory usage had climbed by ~900MB from its starting point.

I used tracemalloc to diff the allocations of single session executions and found that we are allocating ~20MB or so per session in `ModelPatcher.apply_ti()`.

In `ModelPatcher.apply_ti()` we add tokens to the tokenizer when handling TIs. The added tokens should be scoped to only the current invocation, but there is no simple way to remove the tokens afterwards.

As a workaround for this, we clone the tokenizer, add the TI tokens to the clone, and use the clone to when running compel. Afterwards, this cloned tokenizer is discarded.

The tokenizer uses ~20MB of memory, and it has referrers/referents to other compel stuff. This is what is causing the observed increases in memory per session!

We'd expect these objects to be GC'd but python doesn't do it immediately. After creating the cond tensors, we quickly move on to denoising. So there isn't any time for the GC to happen to free up its existing memory arenas/blocks to reuse them. Instead, python needs to request more memory from the OS.

We can improve the situation by immediately calling `del` on the tokenizer clone and related objects. In fact, we already had some code in the compel nodes to `del` some of these objects, but not all.

Adding the `del`s vastly improves things. We hit peak RSS in half the sessions (~50 or less) and it's now ~100MB more than starting value. There is still a gradual increase in memory usage until we level off.
2025-06-11 12:56:16 +10:00
psychedelicious
85f98ab3eb fix(app): error on upload + resize for unusual image modes 2025-06-11 11:18:08 +10:00
Mary Hipp
dac75685be disable publish and cancel buttons once it begins 2025-06-10 19:50:09 -04:00
psychedelicious
d7b5a8b298 fix: opencv dependency conflict (#8095)
* build: prevent `opencv-python` from being installed

Fixes this error: `AttributeError: module 'cv2.ximgproc' has no attribute 'thinning'`

`opencv-contrib-python` supersedes `opencv-python`, providing the same API + additional features. The two packages should not be installed at the same time to avoid conflicts and/or errors.

The `invisible-watermark` package requires `opencv-python`, but we require the contrib variant.

This change updates `pyproject.toml` to prevent `opencv-python` from ever being installed using a `uv` features called dependency overrides.

* feat(ui): data viewer supports disabling wrap

* feat(api): list _all_ pkgs in app deps endpoint

* chore(ui): typegen

* feat(ui): update about modal to display new full deps list

* chore: uv lock
2025-06-10 08:33:41 -04:00
Kent Keirsey
d3ecaa740f Add Precise Reference to Starter Models 2025-06-09 22:02:11 +10:00
dunkeroni
b5a6765a3d also search image creation date 2025-06-09 21:54:26 +10:00
440 changed files with 12900 additions and 7298 deletions

1
.gitignore vendored
View File

@@ -180,6 +180,7 @@ cython_debug/
# Scratch folder
.scratch/
.vscode/
.zed/
# source installer files
installer/*zip

View File

@@ -1,8 +1,7 @@
import typing
from enum import Enum
from importlib.metadata import PackageNotFoundError, version
from importlib.metadata import distributions
from pathlib import Path
from platform import python_version
from typing import Optional
import torch
@@ -44,24 +43,6 @@ class AppVersion(BaseModel):
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
class AppDependencyVersions(BaseModel):
"""App depencency Versions Response"""
accelerate: str = Field(description="accelerate version")
compel: str = Field(description="compel version")
cuda: Optional[str] = Field(description="CUDA version")
diffusers: str = Field(description="diffusers version")
numpy: str = Field(description="Numpy version")
opencv: str = Field(description="OpenCV version")
onnx: str = Field(description="ONNX version")
pillow: str = Field(description="Pillow (PIL) version")
python: str = Field(description="Python version")
torch: str = Field(description="PyTorch version")
torchvision: str = Field(description="PyTorch Vision version")
transformers: str = Field(description="transformers version")
xformers: Optional[str] = Field(description="xformers version")
class AppConfig(BaseModel):
"""App Config Response"""
@@ -76,27 +57,19 @@ async def get_version() -> AppVersion:
return AppVersion(version=__version__)
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
async def get_app_deps() -> AppDependencyVersions:
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=dict[str, str])
async def get_app_deps() -> dict[str, str]:
deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
try:
xformers = version("xformers")
except PackageNotFoundError:
xformers = None
return AppDependencyVersions(
accelerate=version("accelerate"),
compel=version("compel"),
cuda=torch.version.cuda,
diffusers=version("diffusers"),
numpy=version("numpy"),
opencv=version("opencv-python"),
onnx=version("onnx"),
pillow=version("pillow"),
python=python_version(),
torch=torch.version.__version__,
torchvision=version("torchvision"),
transformers=version("transformers"),
xformers=xformers,
)
cuda = torch.version.cuda or "N/A"
except Exception:
cuda = "N/A"
deps["CUDA"] = cuda
sorted_deps = dict(sorted(deps.items(), key=lambda item: item[0].lower()))
return sorted_deps
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)

View File

@@ -1,21 +1,12 @@
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
class AddImagesToBoardResult(BaseModel):
board_id: str = Field(description="The id of the board the images were added to")
added_image_names: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(BaseModel):
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
@board_images_router.post(
"/",
operation_id="add_image_to_board",
@@ -23,17 +14,26 @@ class RemoveImagesFromBoardResult(BaseModel):
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
response_model=AddImagesToBoardResult,
)
async def add_image_to_board(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
):
) -> AddImagesToBoardResult:
"""Creates a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id, image_name=image_name
added_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
added_images.add(image_name)
affected_boards.add(board_id)
affected_boards.add(old_board_id)
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to add image to board")
@@ -45,14 +45,25 @@ async def add_image_to_board(
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
response_model=RemoveImagesFromBoardResult,
)
async def remove_image_from_board(
image_name: str = Body(description="The name of the image to remove", embed=True),
):
) -> RemoveImagesFromBoardResult:
"""Removes an image from its board, if it had one"""
try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
return result
removed_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@@ -72,16 +83,25 @@ async def add_images_to_board(
) -> AddImagesToBoardResult:
"""Adds a list of images to a board"""
try:
added_image_names: list[str] = []
added_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id, image_name=image_name
board_id=board_id,
image_name=image_name,
)
added_image_names.append(image_name)
added_images.add(image_name)
affected_boards.add(board_id)
affected_boards.add(old_board_id)
except Exception:
pass
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@@ -100,13 +120,20 @@ async def remove_images_from_board(
) -> RemoveImagesFromBoardResult:
"""Removes a list of images from their board, if they had one"""
try:
removed_image_names: list[str] = []
removed_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_image_names.append(image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
except Exception:
pass
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove images from board")

View File

@@ -1,7 +1,7 @@
import io
import json
import traceback
from typing import ClassVar, Optional
from typing import ClassVar, Literal, Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
@@ -14,10 +14,17 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageCollectionCounts,
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.images.images_common import (
DeleteImagesResult,
ImageDTO,
ImageUrlsDTO,
StarredImagesResult,
UnstarredImagesResult,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.controlnet_utils import heuristic_resize_fast
@@ -99,7 +106,9 @@ async def upload_image(
raise HTTPException(status_code=400, detail="Invalid resize_to format or size")
try:
np_image = pil_to_np(pil_image)
# heuristic_resize_fast expects an RGB or RGBA image
pil_rgba = pil_image.convert("RGBA")
np_image = pil_to_np(pil_rgba)
np_image = heuristic_resize_fast(np_image, (resize_dims.width, resize_dims.height))
pil_image = np_to_pil(np_image)
except Exception:
@@ -151,18 +160,30 @@ async def create_image_upload_entry(
raise HTTPException(status_code=501, detail="Not implemented")
@images_router.delete("/i/{image_name}", operation_id="delete_image")
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),
) -> None:
) -> DeleteImagesResult:
"""Deletes an image"""
deleted_images: set[str] = set()
affected_boards: set[str] = set()
try:
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
except Exception:
# TODO: Does this need any exception handling at all?
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
@images_router.delete("/intermediates", operation_id="clear_intermediates")
async def clear_intermediates() -> int:
@@ -374,31 +395,32 @@ async def list_image_dtos(
return image_dtos
class DeleteImagesFromListResult(BaseModel):
deleted_images: list[str]
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
async def delete_images_from_list(
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
) -> DeleteImagesFromListResult:
) -> DeleteImagesResult:
try:
deleted_images: list[str] = []
deleted_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.append(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
except Exception:
pass
return DeleteImagesFromListResult(deleted_images=deleted_images)
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@images_router.delete(
"/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesFromListResult
)
async def delete_uncategorized_images() -> DeleteImagesFromListResult:
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
async def delete_uncategorized_images() -> DeleteImagesResult:
"""Deletes all images that are uncategorized"""
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
@@ -406,14 +428,19 @@ async def delete_uncategorized_images() -> DeleteImagesFromListResult:
)
try:
deleted_images: list[str] = []
deleted_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.append(image_name)
deleted_images.add(image_name)
affected_boards.add("none")
except Exception:
pass
return DeleteImagesFromListResult(deleted_images=deleted_images)
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@@ -422,36 +449,50 @@ class ImagesUpdatedFromListResult(BaseModel):
updated_image_names: list[str] = Field(description="The image names that were updated")
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
async def star_images_in_list(
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> ImagesUpdatedFromListResult:
) -> StarredImagesResult:
try:
updated_image_names: list[str] = []
starred_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
updated_image_names.append(image_name)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=True)
)
starred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
except Exception:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
return StarredImagesResult(
starred_images=list(starred_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to star images")
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
async def unstar_images_in_list(
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> ImagesUpdatedFromListResult:
) -> UnstarredImagesResult:
try:
updated_image_names: list[str] = []
unstarred_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
updated_image_names.append(image_name)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=False)
)
unstarred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
except Exception:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
return UnstarredImagesResult(
unstarred_images=list(unstarred_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")
@@ -522,3 +563,92 @@ async def get_bulk_download_item(
return response
except Exception:
raise HTTPException(status_code=404)
@images_router.get(
"/collections/counts", operation_id="get_image_collection_counts", response_model=ImageCollectionCounts
)
async def get_image_collection_counts(
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to count."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to include intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> ImageCollectionCounts:
"""Gets counts for starred and unstarred image collections"""
try:
return ApiDependencies.invoker.services.images.get_collection_counts(
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to get collection counts")
@images_router.get("/collections/{collection}", operation_id="get_image_collection")
async def get_image_collection(
collection: Literal["starred", "unstarred"] = Path(..., description="The collection to retrieve from"),
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
offset: int = Query(default=0, description="The offset within the collection"),
limit: int = Query(default=50, description="The number of images to return"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images from a specific collection (starred or unstarred)"""
try:
image_dtos = ApiDependencies.invoker.services.images.get_collection_images(
collection=collection,
offset=offset,
limit=limit,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
return image_dtos
except Exception:
raise HTTPException(status_code=500, detail="Failed to get collection images")
@images_router.get("/names", operation_id="get_image_names")
async def get_image_names(
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> list[str]:
"""Gets ordered list of all image names (starred first, then unstarred)"""
try:
image_names = ApiDependencies.invoker.services.images.get_image_names(
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
return image_names
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image names")

View File

@@ -14,13 +14,14 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByBatchIDsResult,
CancelByDestinationResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
FieldIdentifier,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -68,7 +69,7 @@ async def enqueue_batch(
"/{queue_id}/list",
operation_id="list_queue_items",
responses={
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
200: {"model": CursorPaginatedResults[SessionQueueItem]},
},
)
async def list_queue_items(
@@ -77,11 +78,36 @@ async def list_queue_items(
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
priority: int = Query(default=0, description="The pagination cursor priority"),
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets all queue items (without graphs)"""
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
@session_queue_router.get(
"/{queue_id}/list_all",
operation_id="list_all_queue_items",
responses={
200: {"model": list[SessionQueueItem]},
},
)
async def list_all_queue_items(
queue_id: str = Path(description="The queue id to perform this operation on"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
@@ -121,6 +147,18 @@ async def cancel_all_except_current(
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
@session_queue_router.put(
"/{queue_id}/delete_all_except_current",
operation_id="delete_all_except_current",
responses={200: {"model": DeleteAllExceptCurrentResult}},
)
async def delete_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> DeleteAllExceptCurrentResult:
"""Immediately deletes all queue items except in-processing items"""
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
@session_queue_router.put(
"/{queue_id}/cancel_by_batch_ids",
operation_id="cancel_by_batch_ids",
@@ -269,6 +307,18 @@ async def get_queue_item(
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
@session_queue_router.delete(
"/{queue_id}/i/{item_id}",
operation_id="delete_queue_item",
)
async def delete_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to delete"),
) -> None:
"""Deletes a queue item"""
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
@session_queue_router.put(
"/{queue_id}/i/{item_id}/cancel",
operation_id="cancel_queue_item",
@@ -298,3 +348,18 @@ async def counts_by_destination(
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
@session_queue_router.delete(
"/{queue_id}/d/{destination}",
operation_id="delete_by_destination",
responses={200: {"model": DeleteByDestinationResult}},
)
async def delete_by_destination(
queue_id: str = Path(description="The queue id to query"),
destination: str = Path(description="The destination to query"),
) -> DeleteByDestinationResult:
"""Deletes all items with the given destination"""
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)

View File

@@ -158,7 +158,7 @@ web_root_path = Path(list(web_dir.__path__)[0])
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here

View File

@@ -499,7 +499,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warn(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
field.json_schema_extra.pop("ui_type")
return None
@@ -582,6 +582,8 @@ def invocation(
fields: dict[str, tuple[Any, FieldInfo]] = {}
original_model_fields: dict[str, OriginalModelField] = {}
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
@@ -589,7 +591,7 @@ def invocation(
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
@@ -613,7 +615,7 @@ def invocation(
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
uiconfig["version"] = version
else:
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
logger.warning(f'No version specified for node "{invocation_type}", using "1.0.0"')
uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)
@@ -676,6 +678,7 @@ def invocation(
docstring = cls.__doc__
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) # type: ignore
new_class.__doc__ = docstring
new_class._original_model_fields = original_model_fields
InvocationRegistry.register_invocation(new_class)

View File

@@ -114,6 +114,13 @@ class CompelInvocation(BaseInvocation):
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info
c = c.detach().to("cpu")
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
@@ -222,7 +229,10 @@ class SDXLPromptInvocationBase:
else:
c_pooled = None
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info

View File

@@ -437,7 +437,7 @@ class WithWorkflow:
workflow = None
def __init_subclass__(cls) -> None:
logger.warn(
logger.warning(
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
)
super().__init_subclass__()
@@ -578,7 +578,7 @@ def InputField(
if default_factory is not _Unset and default_factory is not None:
default = default_factory()
logger.warn('"default_factory" is not supported, calling it now to set "default"')
logger.warning('"default_factory" is not supported, calling it now to set "default"')
# These are the args we may wish pass to the pydantic `Field()` function
field_args = {

View File

@@ -24,7 +24,6 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
@@ -93,7 +92,7 @@ class InvokeAIAppConfig(BaseSettings):
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -176,7 +175,7 @@ class InvokeAIAppConfig(BaseSettings):
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION

View File

@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from typing import Literal, Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageCollectionCounts,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -97,3 +98,44 @@ class ImageRecordStorageBase(ABC):
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the most recent image for a board."""
pass
@abstractmethod
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageCollectionCounts:
"""Gets counts for starred and unstarred image collections."""
pass
@abstractmethod
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images from a specific collection (starred or unstarred)."""
pass
@abstractmethod
def get_image_names(
self,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> list[str]:
"""Gets ordered list of all image names (starred first, then unstarred)."""
pass

View File

@@ -3,7 +3,7 @@ import datetime
from enum import Enum
from typing import Optional, Union
from pydantic import Field, StrictBool, StrictStr
from pydantic import BaseModel, Field, StrictBool, StrictStr
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import get_iso_timestamp
@@ -207,3 +207,8 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
starred=starred,
has_workflow=has_workflow,
)
class ImageCollectionCounts(BaseModel):
starred_count: int = Field(description="The number of starred images in the collection.")
unstarred_count: int = Field(description="The number of unstarred images in the collection.")

View File

@@ -1,12 +1,13 @@
import sqlite3
from datetime import datetime
from typing import Optional, Union, cast
from typing import Literal, Optional, Union, cast
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import (
IMAGE_DTO_COLS,
ImageCategory,
ImageCollectionCounts,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -196,9 +197,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
@@ -382,3 +387,253 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return None
return deserialize_image_record(dict(result))
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageCollectionCounts:
cursor = self._conn.cursor()
# Build the base query conditions (same as get_many)
base_query = """--sql
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count
starred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = TRUE;"
cursor.execute(starred_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get unstarred count
unstarred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = FALSE;"
cursor.execute(unstarred_query, query_params)
unstarred_count = cast(int, cursor.fetchone()[0])
return ImageCollectionCounts(starred_count=starred_count, unstarred_count=unstarred_count)
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
cursor = self._conn.cursor()
# Base queries
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# Add starred/unstarred filter
is_starred = collection == "starred"
query_conditions += """--sql
AND images.starred = ?
"""
query_params.append(is_starred)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Add ordering and pagination
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Execute images query
images_query += query_conditions + query_pagination + ";"
images_params = query_params.copy()
images_params.extend([limit, offset])
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Execute count query
count_query += query_conditions + ";"
cursor.execute(count_query, query_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def get_image_names(
self,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> list[str]:
cursor = self._conn.cursor()
# Base query to get image names in order (starred first, then unstarred)
query = """--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Order by starred first, then by created_at
query += (
query_conditions
+ f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
)
cursor.execute(query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
return [row[0] for row in result]

View File

@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from typing import Callable, Optional
from typing import Callable, Literal, Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageCollectionCounts,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -125,7 +126,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
"""Gets a paginated list of image DTOs with starred images first when starred_first=True."""
pass
@abstractmethod
@@ -147,3 +148,44 @@ class ImageServiceABC(ABC):
def delete_images_on_board(self, board_id: str):
"""Deletes all images on a board."""
pass
@abstractmethod
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageCollectionCounts:
"""Gets counts for starred and unstarred image collections."""
pass
@abstractmethod
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images from a specific collection (starred or unstarred)."""
pass
@abstractmethod
def get_image_names(
self,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> list[str]:
"""Gets ordered list of all image names (starred first, then unstarred)."""
pass

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field
from pydantic import BaseModel, Field
from invokeai.app.services.image_records.image_records_common import ImageRecord
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
@@ -39,3 +39,27 @@ def image_record_to_dto(
thumbnail_url=thumbnail_url,
board_id=board_id,
)
class ResultWithAffectedBoards(BaseModel):
affected_boards: list[str] = Field(description="The ids of boards affected by the delete operation")
class DeleteImagesResult(ResultWithAffectedBoards):
deleted_images: list[str] = Field(description="The names of the images that were deleted")
class StarredImagesResult(ResultWithAffectedBoards):
starred_images: list[str] = Field(description="The names of the images that were starred")
class UnstarredImagesResult(ResultWithAffectedBoards):
unstarred_images: list[str] = Field(description="The names of the images that were unstarred")
class AddImagesToBoardResult(ResultWithAffectedBoards):
added_images: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(ResultWithAffectedBoards):
removed_images: list[str] = Field(description="The image names that were removed from their board")

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional
from PIL.Image import Image as PILImageType
@@ -10,6 +10,7 @@ from invokeai.app.services.image_files.image_files_common import (
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageCollectionCounts,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -78,7 +79,7 @@ class ImageService(ImageServiceABC):
board_id=board_id, image_name=image_name
)
except Exception as e:
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.logger.warning(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
@@ -309,3 +310,90 @@ class ImageService(ImageServiceABC):
except Exception as e:
self.__invoker.services.logger.error("Problem getting intermediates count")
raise e
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageCollectionCounts:
try:
return self.__invoker.services.image_records.get_collection_counts(
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting collection counts")
raise e
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self.__invoker.services.image_records.get_collection_images(
collection=collection,
offset=offset,
limit=limit,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
image_dtos = [
image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
)
for r in results.items
]
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=results.offset,
limit=results.limit,
total=results.total,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting collection images")
raise e
def get_image_names(
self,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> list[str]:
try:
return self.__invoker.services.image_records.get_image_names(
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting image names")
raise e

View File

@@ -148,7 +148,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _clear_pending_jobs(self) -> None:
for job in self.list_jobs():
if not job.in_terminal_state:
self._logger.warning("Cancelling job {job.id}")
self._logger.warning(f"Cancelling job {job.id}")
self.cancel_job(job)
while True:
try:

View File

@@ -1,3 +1,4 @@
import gc
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
@@ -439,6 +440,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
# GC-ing here can reduce peak memory usage of the invoke process by freeing allocated memory blocks.
# Most queue items take seconds to execute, so the relative cost of a GC is very small.
# Python will never cede allocated memory back to the OS, so anything we can do to reduce the peak
# allocation is well worth it.
gc.collect()
self._invoker.services.logger.info(
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
)

View File

@@ -10,6 +10,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
@@ -17,7 +19,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
@@ -92,6 +93,11 @@ class SessionQueueBase(ABC):
"""Cancels a session queue item"""
pass
@abstractmethod
def delete_queue_item(self, item_id: int) -> None:
"""Deletes a session queue item"""
pass
@abstractmethod
def fail_queue_item(
self, item_id: int, error_type: str, error_message: str, error_traceback: str
@@ -109,6 +115,11 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with the given batch destination"""
pass
@abstractmethod
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
"""Deletes all queue items with the given batch destination"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""
@@ -119,6 +130,11 @@ class SessionQueueBase(ABC):
"""Cancels all queue items except in-progress items"""
pass
@abstractmethod
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
"""Deletes all queue items except in-progress items"""
pass
@abstractmethod
def list_queue_items(
self,
@@ -127,10 +143,20 @@ class SessionQueueBase(ABC):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets a page of session queue items"""
pass
@abstractmethod
def list_all_queue_items(
self,
queue_id: str,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
pass
@abstractmethod
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""

View File

@@ -207,7 +207,7 @@ class FieldIdentifier(BaseModel):
field_name: str = Field(description="The name of the field")
class SessionQueueItemWithoutGraph(BaseModel):
class SessionQueueItem(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
item_id: int = Field(description="The identifier of the session queue item")
@@ -251,42 +251,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
default=None,
description="The ID of the published workflow associated with this queue item",
)
api_input_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The fields that were used as input to the API"
)
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
return SessionQueueItemDTO(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra={
"required": [
"item_id",
"status",
"batch_id",
"queue_id",
"session_id",
"priority",
"session_id",
"created_at",
"updated_at",
]
}
)
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass
class SessionQueueItem(SessionQueueItemWithoutGraph):
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow associated with this queue item"
@@ -397,6 +362,18 @@ class CancelByDestinationResult(CancelByBatchIDsResult):
pass
class DeleteByDestinationResult(BaseModel):
"""Result of deleting by a destination"""
deleted: int = Field(..., description="Number of queue items deleted")
class DeleteAllExceptCurrentResult(DeleteByDestinationResult):
"""Result of deleting all except current"""
pass
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""

View File

@@ -17,6 +17,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
@@ -24,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
@@ -46,10 +47,6 @@ class SqliteSessionQueue(SessionQueueBase):
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
else:
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
@@ -104,11 +101,7 @@ class SqliteSessionQueue(SessionQueueBase):
return cast(Union[int, None], cursor.fetchone()[0]) or 0
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
cursor = self._conn.cursor()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
@@ -118,8 +111,12 @@ class SqliteSessionQueue(SessionQueueBase):
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = calc_session_count(batch)
values_to_insert = prepare_values_to_insert(
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
@@ -127,19 +124,16 @@ class SqliteSessionQueue(SessionQueueBase):
)
enqueued_count = len(values_to_insert)
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self._conn.commit()
with self._conn:
cursor = self._conn.cursor()
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
except Exception:
self._conn.rollback()
raise
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
@@ -220,6 +214,19 @@ class SqliteSessionQueue(SessionQueueBase):
) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
""",
(item_id,),
)
row = cursor.fetchone()
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
cursor.execute(
"""--sql
UPDATE session_queue
@@ -331,6 +338,27 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
return queue_item
def delete_queue_item(self, item_id: int) -> None:
"""Deletes a session queue item"""
try:
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE item_id = ?
""",
(item_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
return queue_item
@@ -428,6 +456,71 @@ class SqliteSessionQueue(SessionQueueBase):
raise
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
params = (queue_id, destination)
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
""",
params,
)
count = cursor.fetchone()[0]
cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
""",
params,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id == ?
AND status == 'pending'
"""
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
DELETE
FROM session_queue
{where};
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
cursor = self._conn.cursor()
@@ -543,26 +636,12 @@ class SqliteSessionQueue(SessionQueueBase):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
@@ -574,6 +653,12 @@ class SqliteSessionQueue(SessionQueueBase):
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
@@ -589,7 +674,7 @@ class SqliteSessionQueue(SessionQueueBase):
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
@@ -597,6 +682,37 @@ class SqliteSessionQueue(SessionQueueBase):
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def list_all_queue_items(
self,
queue_id: str,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
cursor_ = self._conn.cursor()
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
cursor = self._conn.cursor()
cursor.execute(

View File

@@ -7,6 +7,7 @@ from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type
import networkx as nx
from pydantic import (
BaseModel,
ConfigDict,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
@@ -787,6 +788,22 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
model_config = ConfigDict(
json_schema_extra={
"required": [
"id",
"graph",
"execution_graph",
"executed",
"executed_history",
"results",
"errors",
"prepared_source_mapping",
"source_prepared_mapping",
]
}
)
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""

View File

@@ -42,4 +42,5 @@ IP-Adapters:
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
- [InvokeAI/ip-adapter-plus_sdxl_vit-h](https://huggingface.co/InvokeAI/ip-adapter-plus_sdxl_vit-h)

View File

@@ -296,7 +296,7 @@ class LoRAConfigBase(ABC, BaseModel):
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd)
value = flux_format_from_state_dict(sd, mod.metadata())
mod.cache[key] = value
return value

View File

@@ -20,6 +20,10 @@ from invokeai.backend.model_manager.taxonomy import (
ModelType,
SubModelType,
)
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
lora_model_from_flux_aitoolkit_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
@@ -92,6 +96,8 @@ class LoRALoader(ModelLoader):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
else:

View File

@@ -297,6 +297,15 @@ ip_adapter_sdxl = StarterModel(
dependencies=[ip_adapter_sdxl_image_encoder],
previous_names=["IP Adapter SDXL"],
)
ip_adapter_plus_sdxl = StarterModel(
name="Precise Reference (IP Adapter Plus ViT-H)",
base=BaseModelType.StableDiffusionXL,
source="https://huggingface.co/InvokeAI/ip-adapter-plus_sdxl_vit-h/resolve/main/ip-adapter-plus_sdxl_vit-h.safetensors",
description="References images with a higher degree of precision.",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sdxl_image_encoder],
previous_names=["IP Adapter Plus SDXL"],
)
ip_adapter_flux = StarterModel(
name="Standard Reference (XLabs FLUX IP-Adapter v2)",
base=BaseModelType.Flux,
@@ -672,6 +681,7 @@ STARTER_MODELS: list[StarterModel] = [
ip_adapter_plus_sd1,
ip_adapter_plus_face_sd1,
ip_adapter_sdxl,
ip_adapter_plus_sdxl,
ip_adapter_flux,
qr_code_cnet_sd1,
qr_code_cnet_sdxl,
@@ -744,6 +754,7 @@ sdxl_bundle: list[StarterModel] = [
juggernaut_sdxl,
sdxl_fp16_vae_fix,
ip_adapter_sdxl,
ip_adapter_plus_sdxl,
canny_sdxl,
depth_sdxl,
softedge_sdxl,

View File

@@ -137,6 +137,7 @@ class FluxLoRAFormat(str, Enum):
Kohya = "flux.kohya"
OneTrainer = "flux.onetrainer"
Control = "flux.control"
AIToolkit = "flux.aitoolkit"
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]

View File

@@ -46,6 +46,10 @@ class ModelPatcher:
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
if len(ti_list) == 0:
yield tokenizer, TextualInversionManager(tokenizer)
return
init_tokens_count = None
new_tokens_added = None

View File

@@ -0,0 +1,63 @@
import json
from dataclasses import dataclass, field
from typing import Any
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import _group_by_layer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.util import InvokeAILogger
def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
if metadata:
try:
software = json.loads(metadata.get("software", "{}"))
except json.JSONDecodeError:
return False
return software.get("name") == "ai-toolkit"
# metadata got lost somewhere
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
@dataclass
class GroupedStateDict:
transformer: dict[str, Any] = field(default_factory=dict)
# might also grow CLIP and T5 submodels
def _group_state_by_submodel(state_dict: dict[str, Any]) -> GroupedStateDict:
logger = InvokeAILogger.get_logger()
grouped = GroupedStateDict()
for key, value in state_dict.items():
submodel_name, param_name = key.split(".", 1)
match submodel_name:
case "diffusion_model":
grouped.transformer[param_name] = value
case _:
logger.warning(f"Unexpected submodel name: {submodel_name}")
return grouped
def _rename_peft_lora_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Renames keys from the PEFT LoRA format to the InvokeAI format."""
renamed_state_dict = {}
for key, value in state_dict.items():
renamed_key = key.replace(".lora_A.", ".lora_down.").replace(".lora_B.", ".lora_up.")
renamed_state_dict[renamed_key] = value
return renamed_state_dict
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
state_dict = _rename_peft_lora_keys(state_dict)
by_layer = _group_by_layer(state_dict)
by_model = _group_state_by_submodel(by_layer)
layers: dict[str, BaseLayerPatch] = {}
for layer_key, layer_state_dict in by_model.transformer.items():
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
return ModelPatchRaw(layers=layers)

View File

@@ -1,4 +1,7 @@
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
@@ -11,7 +14,7 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
)
def flux_format_from_state_dict(state_dict):
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
@@ -20,5 +23,7 @@ def flux_format_from_state_dict(state_dict):
return FluxLoRAFormat.Diffusers
elif is_state_dict_likely_flux_control(state_dict):
return FluxLoRAFormat.Control
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
return FluxLoRAFormat.AIToolkit
else:
return None

View File

@@ -9,13 +9,16 @@ module.exports = {
// https://github.com/qdanik/eslint-plugin-path
'path/no-relative-imports': ['error', { maxDepth: 0 }],
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
'i18next/no-literal-string': 'error',
// TODO: ENABLE THIS RULE BEFORE v6.0.0
// 'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'error',
'no-console': 'warn',
// https://eslint.org/docs/latest/rules/no-promise-executor-return
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
// TODO: ENABLE THIS RULE BEFORE v6.0.0
'react/display-name': 'off',
'no-restricted-properties': [
'error',
{

View File

@@ -3,6 +3,8 @@ import type { KnipConfig } from 'knip';
const config: KnipConfig = {
project: ['src/**/*.{ts,tsx}!'],
ignore: [
// TODO(psyche): temporarily ignored all files for test build purposes
'src/**',
// This file is only used during debugging
'src/app/store/middleware/debugLoggerMiddleware.ts',
// Autogenerated types - shouldn't ever touch these

View File

@@ -60,15 +60,16 @@
"@fontsource-variable/inter": "^5.2.5",
"@invoke-ai/ui-library": "^0.0.46",
"@nanostores/react": "^1.0.0",
"@reduxjs/toolkit": "2.7.0",
"@reduxjs/toolkit": "2.8.2",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.6.0",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.1.1",
"compare-versions": "^6.1.1",
"dockview": "^4.3.1",
"filesize": "^10.1.6",
"fracturedjsonjs": "^4.0.2",
"fracturedjsonjs": "^4.1.0",
"framer-motion": "^11.10.0",
"i18next": "^25.0.1",
"i18next-http-backend": "^3.0.2",

View File

@@ -30,8 +30,8 @@ dependencies:
specifier: ^1.0.0
version: 1.0.0(nanostores@1.0.1)(react@18.3.1)
'@reduxjs/toolkit':
specifier: 2.7.0
version: 2.7.0(react-redux@9.2.0)(react@18.3.1)
specifier: 2.8.2
version: 2.8.2(react-redux@9.2.0)(react@18.3.1)
'@roarr/browser-log-writer':
specifier: ^1.3.0
version: 1.3.0
@@ -50,12 +50,15 @@ dependencies:
compare-versions:
specifier: ^6.1.1
version: 6.1.1
dockview:
specifier: ^4.3.1
version: 4.3.1(react@18.3.1)
filesize:
specifier: ^10.1.6
version: 10.1.6
fracturedjsonjs:
specifier: ^4.0.2
version: 4.0.2
specifier: ^4.1.0
version: 4.1.0
framer-motion:
specifier: ^11.10.0
version: 11.10.0(react-dom@18.3.1)(react@18.3.1)
@@ -2161,8 +2164,8 @@ packages:
- supports-color
dev: true
/@reduxjs/toolkit@2.7.0(react-redux@9.2.0)(react@18.3.1):
resolution: {integrity: sha512-XVwolG6eTqwV0N8z/oDlN93ITCIGIop6leXlGJI/4EKy+0POYkR+ABHRSdGXY+0MQvJBP8yAzh+EYFxTuvmBiQ==}
/@reduxjs/toolkit@2.8.2(react-redux@9.2.0)(react@18.3.1):
resolution: {integrity: sha512-MYlOhQ0sLdw4ud48FoC5w0dH9VfWQjtCjreKwYTT3l+r427qYC5Y8PihNutepr8XrNaBUDQo9khWUwQxZaqt5A==}
peerDependencies:
react: ^16.9.0 || ^17.0.0 || ^18 || ^19
react-redux: ^7.2.1 || ^8.1.3 || ^9.0.0
@@ -4492,6 +4495,19 @@ packages:
resolution: {integrity: sha512-c68LpLbO+7kP/b1Hr1qs8/BJ09F5khZGTxqxZuhzxpmwJKOgRFHJWIb9/KmqnqHhLdO55aOxFH/EGBvUQbL/RQ==}
dev: false
/dockview-core@4.3.1:
resolution: {integrity: sha512-cjGIXKc1wtHHkeKisuDLNt3HSHCVzvabxm1K9Auna27A9T3QR7ISOiTJyEUKUPllkcztFYBut0vwnnvwLnPAuQ==}
dev: false
/dockview@4.3.1(react@18.3.1):
resolution: {integrity: sha512-D4SvZPs1GJxGUBPkrehlKNGsWlSDaBiPuSYI+IEXnZ7b2bCUs1/h954sVs7xyykqEW3r6TkPKLWdTR/47Q7/QQ==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
dependencies:
dockview-core: 4.3.1
react: 18.3.1
dev: false
/doctrine@2.1.0:
resolution: {integrity: sha512-35mSku4ZXK0vfCuHEDAwt55dg2jNajHZ1odvF+8SSr82EsZY4QmXfuWso8oEd8zRhVObSN18aM0CjSdoBX7zIw==}
engines: {node: '>=0.10.0'}
@@ -5280,8 +5296,8 @@ packages:
signal-exit: 4.1.0
dev: true
/fracturedjsonjs@4.0.2:
resolution: {integrity: sha512-+vGJH9wK0EEhbbn50V2sOebLRaar1VL3EXr02kxchIwpkhQk0ItrPjIOtYPYuU9hNFpVzxjrPgzjtMJih+ae4A==}
/fracturedjsonjs@4.1.0:
resolution: {integrity: sha512-qy6LPA8OOiiyRHt5/sNKDayD7h5r3uHmHxSOLbBsgtU/hkt5vOVWOR51MdfDbeCNfj7k/dKCRbXYm8FBAJcgWQ==}
dev: false
/framer-motion@10.18.0(react-dom@18.3.1)(react@18.3.1):

View File

@@ -2015,7 +2015,9 @@
"resetGenerationSettings": "Reset Generation Settings",
"replaceCurrent": "Replace Current",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or <PullBboxButton>pull the bounding box into this layer</PullBboxButton> to get started.",
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image to get started.",
"uploadOrDragAnImage": "Drag an image from the gallery or <UploadButton>upload an image</UploadButton>.",
"imageNoise": "Image Noise",
"denoiseLimit": "Denoise Limit",
"warnings": {

View File

@@ -3,7 +3,7 @@ import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import { $globalIsLoading } from 'app/store/nanostores/globalIsLoading';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
@@ -20,7 +20,7 @@ interface Props {
}
const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
const didStudioInit = useStore($didStudioInit);
const globalIsLoading = useStore($globalIsLoading);
const clearStorage = useClearStorage();
const handleReset = useCallback(() => {
@@ -33,7 +33,7 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
{globalIsLoading && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />

View File

@@ -8,6 +8,7 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useCloseChakraTooltipsOnDragFix } from 'common/hooks/useCloseChakraTooltipsOnDragFix';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
@@ -19,8 +20,11 @@ import i18n from 'i18n';
import { size } from 'lodash-es';
import { memo, useEffect } from 'react';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import { useGetQueueCountsByDestinationQuery } from 'services/api/endpoints/queue';
import { useSocketIO } from 'services/events/useSocketIO';
const queueCountArg = { destination: 'canvas' };
/**
* GlobalHookIsolator is a logical component that runs global hooks in an isolated component, so that they do not
* cause needless re-renders of any other components.
@@ -38,6 +42,11 @@ export const GlobalHookIsolator = memo(
useGlobalHotkeys();
useGetOpenAPISchemaQuery();
useSyncLoggingConfig();
useCloseChakraTooltipsOnDragFix();
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
// and/or in progress canvas sessions.
useGetQueueCountsByDestinationQuery(queueCountArg);
useEffect(() => {
i18n.changeLanguage(language);

View File

@@ -7,11 +7,13 @@ import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const GlobalImageHotkeys = memo(() => {
useAssertSingleton('GlobalImageHotkeys');
const imageDTO = useAppSelector(selectLastSelectedImage);
const imageName = useAppSelector(selectLastSelectedImage);
const imageDTO = useImageDTO(imageName);
if (!imageDTO) {
return null;

View File

@@ -6,7 +6,7 @@ import {
NewGallerySessionDialog,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
@@ -15,6 +15,7 @@ import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/workflow
import { WorkflowLibraryModal } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryModal';
import { CancelAllExceptCurrentQueueItemConfirmationAlertDialog } from 'features/queue/components/CancelAllExceptCurrentQueueItemConfirmationAlertDialog';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { DeleteAllExceptCurrentQueueItemConfirmationAlertDialog } from 'features/queue/components/DeleteAllExceptCurrentQueueItemConfirmationAlertDialog';
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
@@ -39,6 +40,7 @@ export const GlobalModalIsolator = memo(() => {
<StylePresetModal />
<WorkflowLibraryModal />
<CancelAllExceptCurrentQueueItemConfirmationAlertDialog />
<DeleteAllExceptCurrentQueueItemConfirmationAlertDialog />
<ClearQueueConfirmationsAlertDialog />
<NewWorkflowConfirmationAlertDialog />
<LoadWorkflowConfirmationAlertDialog />

View File

@@ -1,6 +1,7 @@
import '@fontsource-variable/inter';
import 'overlayscrollbars/overlayscrollbars.css';
import '@xyflow/react/dist/base.css';
import 'common/components/OverlayScrollbars/overlayscrollbars.css';
import { ChakraProvider, DarkMode, extendTheme, theme as _theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
import type { ReactNode } from 'react';

View File

@@ -3,11 +3,10 @@ import { useAppStore } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { withResultAsync } from 'common/util/result';
import { canvasReset } from 'features/controlLayers/store/actions';
import { settingsSendToCanvasChanged } from 'features/controlLayers/store/canvasSettingsSlice';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { sentImageToCanvas } from 'features/gallery/store/actions';
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
@@ -93,10 +92,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
};
store.dispatch(canvasReset());
store.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
store.dispatch(settingsSendToCanvasChanged(true));
store.dispatch(setActiveTab('canvas'));
store.dispatch(sentImageToCanvas());
$imageViewer.set(false);
toast({
title: t('toast.sentToCanvas'),
status: 'info',
@@ -118,9 +114,9 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
return;
}
const metadata = getImageMetadataResult.value;
store.dispatch(canvasReset());
// This shows a toast
await parseAndRecallAllMetadata(metadata, true);
store.dispatch(setActiveTab('canvas'));
},
[store, t]
);
@@ -164,16 +160,12 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
switch (destination) {
case 'generation':
// Go to the canvas tab, open the image viewer, and enable send-to-gallery mode
store.dispatch(setActiveTab('canvas'));
store.dispatch(paramsReset());
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
store.dispatch(settingsSendToCanvasChanged(false));
$imageViewer.set(true);
break;
case 'canvas':
// Go to the canvas tab, close the image viewer, and disable send-to-gallery mode
store.dispatch(setActiveTab('canvas'));
store.dispatch(settingsSendToCanvasChanged(true));
$imageViewer.set(false);
store.dispatch(canvasReset());
break;
case 'workflows':
// Go to the workflows tab

View File

@@ -1,7 +1,6 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@@ -10,15 +9,12 @@ import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/l
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnsureImageIsSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/ensureImageIsSelectedListener';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageDeletionListeners } from 'app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImagesStarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesStarred';
import { addImagesUnstarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesUnstarred';
import { addImageToDeleteSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected';
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
@@ -47,13 +43,7 @@ export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addImageDeletionListeners(startAppListening);
addDeleteBoardAndImagesFulfilledListener(startAppListening);
addImageToDeleteSelectedListener(startAppListening);
// Image starred
addImagesStarredListener(startAppListening);
addImagesUnstarredListener(startAppListening);
// Gallery
addGalleryImageClickedListener(startAppListening);
@@ -65,9 +55,6 @@ addEnqueueRequestedUpscale(startAppListening);
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Canvas actions
addStagingListeners(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);
@@ -95,3 +82,5 @@ addAppConfigReceivedListener(startAppListening);
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);
addEnsureImageIsSelectedListener(startAppListening);

View File

@@ -25,7 +25,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
matcher: matchAnyBoardDeleted,
effect: (action, { dispatch, getState }) => {
const state = getState();
const deletedBoardId = action.meta.arg.originalArgs;
const deletedBoardId = action.meta.arg.originalArgs.board_id;
const { autoAddBoardId, selectedBoardId } = state.gallery;
// If the deleted board was currently selected, we should reset the selected board to uncategorized

View File

@@ -1,46 +0,0 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasReset, newSessionRequested } from 'features/controlLayers/store/actions';
import { stagingAreaReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
const log = logger('canvas');
const matchCanvasOrStagingAreaReset = isAnyOf(stagingAreaReset, canvasReset, newSessionRequested);
export const addStagingListeners = (startAppListening: AppStartListening) => {
startAppListening({
matcher: matchCanvasOrStagingAreaReset,
effect: async (_, { dispatch }) => {
try {
const req = dispatch(
queueApi.endpoints.cancelByBatchDestination.initiate(
{ destination: 'canvas' },
{ fixedCacheKey: 'cancelByBatchOrigin' }
)
);
const { canceled } = await req.unwrap();
req.reset();
if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`);
toast({
id: 'CANCEL_BATCH_SUCCEEDED',
title: t('queue.cancelBatchSucceeded'),
status: 'success',
});
}
} catch {
log.error('Failed to cancel canvas batches');
toast({
id: 'CANCEL_BATCH_FAILED',
title: t('queue.cancelBatchFailed'),
status: 'error',
});
}
},
});
};

View File

@@ -1,6 +1,7 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
@@ -20,9 +21,10 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
const nodes = selectNodesSlice(state);
const canvas = selectCanvasSlice(state);
const upscale = selectUpscaleSlice(state);
const refImages = selectRefImagesSlice(state);
deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(nodes, canvas, upscale, image_name);
const imageUsage = getImageUsage(nodes, canvas, upscale, refImages, image_name);
if (imageUsage.isNodesImage && !wasNodeEditorReset) {
dispatch(nodeEditorReset());

View File

@@ -30,9 +30,9 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const selectedImage = boardImagesData.items.find(
(item) => item.image_name === action.payload.selectedImageName
);
dispatch(imageSelected(selectedImage || null));
dispatch(imageSelected(selectedImage?.image_name ?? null));
} else if (boardImagesData) {
dispatch(imageSelected(boardImagesData.items[0] || null));
dispatch(imageSelected(boardImagesData.items[0]?.image_name ?? null));
} else {
// board has no images - deselect
dispatch(imageSelected(null));

View File

@@ -5,6 +5,12 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import {
canvasSessionIdCreated,
generateSessionIdCreated,
selectCanvasSessionId,
selectGenerateSessionId,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
@@ -17,6 +23,7 @@ import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Grap
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
@@ -30,11 +37,34 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
actionCreator: enqueueRequestedCanvas,
effect: async (action, { getState, dispatch }) => {
log.debug('Enqueue requested');
const tab = selectActiveTab(getState());
let sessionId = null;
if (tab === 'generate') {
sessionId = selectGenerateSessionId(getState());
if (!sessionId) {
dispatch(generateSessionIdCreated());
sessionId = selectGenerateSessionId(getState());
}
} else if (tab === 'canvas') {
sessionId = selectCanvasSessionId(getState());
if (!sessionId) {
dispatch(canvasSessionIdCreated());
sessionId = selectCanvasSessionId(getState());
}
} else {
log.warn(`Enqueue requested in unsupported tab ${tab}`);
return;
}
const state = getState();
const destination = sessionId;
assert(destination !== null);
const { prepend } = action.payload;
const manager = $canvasManager.get();
assert(manager, 'No canvas manager');
// assert(manager, 'No canvas manager');
const model = state.params.model;
assert(model, 'No model found in state');
@@ -87,8 +117,6 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,

View File

@@ -0,0 +1,16 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
export const addEnsureImageIsSelectedListener = (startAppListening: AppStartListening) => {
// When we list images, if no images is selected, select the first one.
startAppListening({
matcher: imagesApi.endpoints.listImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const selection = getState().gallery.selection;
if (selection.length === 0) {
dispatch(imageSelected(action.payload.items[0]?.image_name ?? null));
}
},
});
};

View File

@@ -1,12 +1,32 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import type { RootState } from 'app/store/store';
import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { uniq } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { ImageCategory, SQLiteDirection } from 'services/api/types';
// Type for image collection query arguments
type ImageCollectionQueryArgs = {
board_id?: string;
categories?: ImageCategory[];
search_term?: string;
order_dir?: SQLiteDirection;
is_intermediate: boolean;
};
/**
* Helper function to get cached image names list for selection operations
* Returns an ordered array of image names (starred first, then unstarred)
*/
const getCachedImageNames = (state: RootState, queryArgs: ImageCollectionQueryArgs): string[] => {
const queryResult = imagesApi.endpoints.getImageNames.select(queryArgs)(state);
return queryResult.data || [];
};
export const galleryImageClicked = createAction<{
imageDTO: ImageDTO;
imageName: string;
shiftKey: boolean;
ctrlKey: boolean;
metaKey: boolean;
@@ -28,45 +48,51 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
startAppListening({
actionCreator: galleryImageClicked,
effect: (action, { dispatch, getState }) => {
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const { imageName, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
const queryArgs = selectImageCollectionQueryArgs(state);
if (!queryResult.data) {
// Should never happen if we have clicked a gallery image
// Get cached image names for selection operations
const imageNames = getCachedImageNames(state, queryArgs);
// If we don't have the image names cached, we can't perform selection operations
// This can happen if the user clicks on an image before the names are loaded
if (imageNames.length === 0) {
// For basic click without modifiers, we can still set selection
if (!shiftKey && !ctrlKey && !metaKey && !altKey) {
dispatch(selectionChanged([imageName]));
}
return;
}
const imageDTOs = queryResult.data.items;
const selection = state.gallery.selection;
if (altKey) {
if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
if (state.gallery.imageToCompare === imageName) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(imageToCompareChanged(imageDTO));
dispatch(imageToCompareChanged(imageName));
}
} else if (shiftKey) {
const rangeEndImageName = imageDTO.image_name;
const lastSelectedImage = selection[selection.length - 1]?.image_name;
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);
const currentClickedIndex = imageDTOs.findIndex((n) => n.image_name === rangeEndImageName);
const rangeEndImageName = imageName;
const lastSelectedImage = selection.at(-1);
const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage);
const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const imagesToSelect = imageDTOs.slice(start, end + 1);
dispatch(selectionChanged(selection.concat(imagesToSelect)));
const imagesToSelect = imageNames.slice(start, end + 1);
dispatch(selectionChanged(uniq(selection.concat(imagesToSelect))));
}
} else if (ctrlKey || metaKey) {
if (selection.some((i) => i.image_name === imageDTO.image_name) && selection.length > 1) {
dispatch(selectionChanged(selection.filter((n) => n.image_name !== imageDTO.image_name)));
if (selection.some((n) => n === imageName) && selection.length > 1) {
dispatch(selectionChanged(uniq(selection.filter((n) => n !== imageName))));
} else {
dispatch(selectionChanged(selection.concat(imageDTO)));
dispatch(selectionChanged(uniq(selection.concat(imageName))));
}
} else {
dispatch(selectionChanged([imageDTO]));
dispatch(selectionChanged([imageName]));
}
},
});

View File

@@ -84,14 +84,14 @@ export const addGalleryOffsetChangedListener = (startAppListening: AppStartListe
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (!selection.some((selectedImage) => selectedImage.image_name === lastImage?.image_name)) {
dispatch(selectionChanged(lastImage ? [lastImage] : []));
if (!selection.some((selectedImage) => selectedImage === lastImage?.image_name)) {
dispatch(selectionChanged(lastImage ? [lastImage.image_name] : []));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (!selection.some((selectedImage) => selectedImage.image_name === firstImage?.image_name)) {
dispatch(selectionChanged(firstImage ? [firstImage] : []));
if (!selection.some((selectedImage) => selectedImage === firstImage?.image_name)) {
dispatch(selectionChanged(firstImage ? [firstImage.image_name] : []));
}
}
return;
@@ -102,14 +102,14 @@ export const addGalleryOffsetChangedListener = (startAppListening: AppStartListe
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (lastImage && imageToCompare?.image_name !== lastImage.image_name) {
dispatch(imageToCompareChanged(lastImage));
if (lastImage && imageToCompare !== lastImage.image_name) {
dispatch(imageToCompareChanged(lastImage.image_name));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (firstImage && imageToCompare?.image_name !== firstImage.image_name) {
dispatch(imageToCompareChanged(firstImage));
if (firstImage && imageToCompare !== firstImage.image_name) {
dispatch(imageToCompareChanged(firstImage.image_name));
}
}
return;

View File

@@ -8,16 +8,16 @@ export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStar
startAppListening({
matcher: imagesApi.endpoints.addImageToBoard.matchFulfilled,
effect: (action) => {
const { board_id, imageDTO } = action.meta.arg.originalArgs;
log.debug({ board_id, imageDTO }, 'Image added to board');
const { board_id, image_name } = action.meta.arg.originalArgs;
log.debug({ board_id, image_name }, 'Image added to board');
},
});
startAppListening({
matcher: imagesApi.endpoints.addImageToBoard.matchRejected,
effect: (action) => {
const { board_id, imageDTO } = action.meta.arg.originalArgs;
log.debug({ board_id, imageDTO }, 'Problem adding image to board');
const { board_id, image_name } = action.meta.arg.originalArgs;
log.debug({ board_id, image_name }, 'Problem adding image to board');
},
});
};

View File

@@ -1,221 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import { entityDeleted, referenceImageIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isImageFieldCollectionInputInstance, isImageFieldInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { forEach, intersectionBy } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { Param0 } from 'tsafe';
const log = logger('gallery');
//TODO(psyche): handle image deletion (canvas staging area?)
// Some utils to delete images from different parts of the app
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
const actions: Param0<typeof dispatch>[] = [];
state.nodes.present.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
forEach(node.data.inputs, (input) => {
if (isImageFieldInputInstance(input) && input.value?.image_name === imageDTO.image_name) {
actions.push(
fieldImageValueChanged({
nodeId: node.data.id,
fieldName: input.name,
value: undefined,
})
);
return;
}
if (isImageFieldCollectionInputInstance(input)) {
actions.push(
fieldImageCollectionValueChanged({
nodeId: node.data.id,
fieldName: input.name,
value: input.value?.filter((value) => value?.image_name !== imageDTO.image_name),
})
);
}
});
});
actions.forEach(dispatch);
};
const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).controlLayers.entities.forEach(({ id, objects }) => {
let shouldDelete = false;
for (const obj of objects) {
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
shouldDelete = true;
break;
}
}
if (shouldDelete) {
dispatch(entityDeleted({ entityIdentifier: { id, type: 'control_layer' } }));
}
});
};
const deleteReferenceImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
}
});
};
const deleteRasterLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).rasterLayers.entities.forEach(({ id, objects }) => {
let shouldDelete = false;
for (const obj of objects) {
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
shouldDelete = true;
break;
}
}
if (shouldDelete) {
dispatch(entityDeleted({ entityIdentifier: { id, type: 'raster_layer' } }));
}
});
};
export const addImageDeletionListeners = (startAppListening: AppStartListening) => {
// Handle single image deletion
startAppListening({
actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length !== 1 || imagesUsage.length !== 1) {
// handle multiples in separate listener
return;
}
const imageDTO = imageDTOs[0];
const imageUsage = imagesUsage[0];
if (!imageDTO || !imageUsage) {
// satisfy noUncheckedIndexedAccess
return;
}
try {
const state = getState();
await dispatch(imagesApi.endpoints.deleteImage.initiate(imageDTO)).unwrap();
if (state.gallery.selection.some((i) => i.image_name === imageDTO.image_name)) {
// The deleted image was a selected image, we need to select the next image
const newSelection = state.gallery.selection.filter((i) => i.image_name !== imageDTO.image_name);
if (newSelection.length > 0) {
return;
}
// Get the current list of images and select the same index
const baseQueryArgs = selectListImagesQueryArgs(state);
const data = imagesApi.endpoints.listImages.select(baseQueryArgs)(state).data;
if (data) {
const deletedImageIndex = data.items.findIndex((i) => i.image_name === imageDTO.image_name);
const nextImage = data.items[deletedImageIndex + 1] ?? data.items[0] ?? null;
if (nextImage?.image_name === imageDTO.image_name) {
// If the next image is the same as the deleted one, it means it was the last image, reset selection
dispatch(imageSelected(null));
} else {
dispatch(imageSelected(nextImage));
}
}
}
deleteNodesImages(state, dispatch, imageDTO);
deleteReferenceImages(state, dispatch, imageDTO);
deleteRasterLayerImages(state, dispatch, imageDTO);
deleteControlLayerImages(state, dispatch, imageDTO);
} catch {
// no-op
} finally {
dispatch(isModalOpenChanged(false));
}
},
});
// Handle multiple image deletion
startAppListening({
actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length <= 1 || imagesUsage.length <= 1) {
// handle singles in separate listener
return;
}
try {
const state = getState();
await dispatch(imagesApi.endpoints.deleteImages.initiate({ imageDTOs })).unwrap();
if (intersectionBy(state.gallery.selection, imageDTOs, 'image_name').length > 0) {
// Some selected images were deleted, need to select the next image
const queryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (data) {
// When we delete multiple images, we clear the selection. Then, the the next time we load images, we will
// select the first one. This is handled below in the listener for `imagesApi.endpoints.listImages.matchFulfilled`.
dispatch(imageSelected(null));
}
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
imageDTOs.forEach((imageDTO) => {
deleteNodesImages(state, dispatch, imageDTO);
deleteControlLayerImages(state, dispatch, imageDTO);
deleteReferenceImages(state, dispatch, imageDTO);
deleteRasterLayerImages(state, dispatch, imageDTO);
});
} catch {
// no-op
} finally {
dispatch(isModalOpenChanged(false));
}
},
});
// When we list images, if no images is selected, select the first one.
startAppListening({
matcher: imagesApi.endpoints.listImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const selection = getState().gallery.selection;
if (selection.length === 0) {
dispatch(imageSelected(action.payload.items[0] ?? null));
}
},
});
startAppListening({
matcher: imagesApi.endpoints.deleteImage.matchFulfilled,
effect: (action) => {
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Image deleted');
},
});
startAppListening({
matcher: imagesApi.endpoints.deleteImage.matchRejected,
effect: (action) => {
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Unable to delete image');
},
});
};

View File

@@ -1,32 +0,0 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { selectImageUsage } from 'features/deleteImageModal/store/selectors';
import { imagesToDeleteSelected, isModalOpenChanged } from 'features/deleteImageModal/store/slice';
export const addImageToDeleteSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: imagesToDeleteSelected,
effect: (action, { dispatch, getState }) => {
const imageDTOs = action.payload;
const state = getState();
const { shouldConfirmOnDelete } = state.system;
const imagesUsage = selectImageUsage(getState());
const isImageInUse =
imagesUsage.some((i) => i.isRasterLayerImage) ||
imagesUsage.some((i) => i.isControlLayerImage) ||
imagesUsage.some((i) => i.isReferenceImage) ||
imagesUsage.some((i) => i.isInpaintMaskImage) ||
imagesUsage.some((i) => i.isUpscaleImage) ||
imagesUsage.some((i) => i.isNodesImage) ||
imagesUsage.some((i) => i.isRegionalGuidanceImage);
if (shouldConfirmOnDelete || isImageInUse) {
dispatch(isModalOpenChanged(true));
return;
}
dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage }));
},
});
};

View File

@@ -1,30 +0,0 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const addImagesStarredListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.starImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const { updated_image_names: starredImages } = action.payload;
const state = getState();
const { selection } = state.gallery;
const updatedSelection: ImageDTO[] = [];
selection.forEach((selectedImageDTO) => {
if (starredImages.includes(selectedImageDTO.image_name)) {
updatedSelection.push({
...selectedImageDTO,
starred: true,
});
} else {
updatedSelection.push(selectedImageDTO);
}
});
dispatch(selectionChanged(updatedSelection));
},
});
};

View File

@@ -1,30 +0,0 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const addImagesUnstarredListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const { updated_image_names: unstarredImages } = action.payload;
const state = getState();
const { selection } = state.gallery;
const updatedSelection: ImageDTO[] = [];
selection.forEach((selectedImageDTO) => {
if (unstarredImages.includes(selectedImageDTO.image_name)) {
updatedSelection.push({
...selectedImageDTO,
starred: false,
});
} else {
updatedSelection.push(selectedImageDTO);
}
});
dispatch(selectionChanged(updatedSelection));
},
});
};

View File

@@ -1,11 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import {
controlLayerModelChanged,
referenceImageIPAdapterModelChanged,
rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasSlice';
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import {
clipEmbedModelSelected,
@@ -15,8 +11,9 @@ import {
t5EncoderModelSelected,
vaeSelected,
} from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { getEntityIdentifier, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
import { modelSelected } from 'features/parameters/store/actions';
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import {
@@ -210,12 +207,12 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log)
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'ip_adapter') {
selectRefImagesSlice(state).entities.forEach((entity) => {
if (!isIPAdapterConfig(entity.config)) {
return;
}
const selectedIPAdapterModel = entity.ipAdapter.model;
const selectedIPAdapterModel = entity.config.model;
// `null` is a valid IP adapter model - no need to do anything.
if (!selectedIPAdapterModel) {
return;
@@ -225,16 +222,16 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
return;
}
log.debug({ selectedIPAdapterModel }, 'Selected IP adapter model is not available, clearing');
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
});
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
if (ipAdapter.type !== 'ip_adapter') {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isIPAdapterConfig(config)) {
return;
}
const selectedIPAdapterModel = ipAdapter.model;
const selectedIPAdapterModel = config.model;
// `null` is a valid IP adapter model - no need to do anything.
if (!selectedIPAdapterModel) {
return;
@@ -245,7 +242,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
}
log.debug({ selectedIPAdapterModel }, 'Selected IP adapter model is not available, clearing');
dispatch(
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
rgRefImageModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
);
});
});
@@ -254,11 +251,11 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
const fluxReduxModels = models.filter(isFluxReduxModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'flux_redux') {
selectRefImagesSlice(state).entities.forEach((entity) => {
if (!isFLUXReduxConfig(entity.config)) {
return;
}
const selectedFLUXReduxModel = entity.ipAdapter.model;
const selectedFLUXReduxModel = entity.config.model;
// `null` is a valid FLUX Redux model - no need to do anything.
if (!selectedFLUXReduxModel) {
return;
@@ -268,16 +265,16 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
return;
}
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
});
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
if (ipAdapter.type !== 'flux_redux') {
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
if (!isFLUXReduxConfig(config)) {
return;
}
const selectedFLUXReduxModel = ipAdapter.model;
const selectedFLUXReduxModel = config.model;
// `null` is a valid FLUX Redux model - no need to do anything.
if (!selectedFLUXReduxModel) {
return;
@@ -288,7 +285,7 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
}
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
dispatch(
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
rgRefImageModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
);
});
});

View File

@@ -0,0 +1,13 @@
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import { atom, computed } from 'nanostores';
import { flushSync } from 'react-dom';
export const $isLayoutLoading = atom(false);
export const setIsLayoutLoading = (isLoading: boolean) => {
flushSync(() => {
$isLayoutLoading.set(isLoading);
});
};
export const $globalIsLoading = computed([$didStudioInit, $isLayoutLoading], (didStudioInit, isLayoutLoading) => {
return !didStudioInit || isLayoutLoading;
});

View File

@@ -8,12 +8,12 @@ import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
import {
canvasSessionSlice,
canvasStagingAreaPersistConfig,
canvasStagingAreaSlice,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
@@ -54,7 +54,6 @@ const allReducers = {
[configSlice.name]: configSlice.reducer,
[uiSlice.name]: uiSlice.reducer,
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer,
[deleteImageModalSlice.name]: deleteImageModalSlice.reducer,
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[queueSlice.name]: queueSlice.reducer,
@@ -65,9 +64,10 @@ const allReducers = {
[stylePresetSlice.name]: stylePresetSlice.reducer,
[paramsSlice.name]: paramsSlice.reducer,
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
[canvasStagingAreaSlice.name]: canvasStagingAreaSlice.reducer,
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
[lorasSlice.name]: lorasSlice.reducer,
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
[refImagesSlice.name]: refImagesSlice.reducer,
};
const rootReducer = combineReducers(allReducers);
@@ -113,6 +113,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
[lorasPersistConfig.name]: lorasPersistConfig,
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
[refImagesSlice.name]: refImagesPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {
@@ -175,6 +176,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
.concat(api.middleware)
.concat(dynamicMiddlewares)
.concat(authToastMiddleware)
// .concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
@@ -209,3 +211,4 @@ export type RootState = ReturnType<AppStore['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
export type AppGetState = ReturnType<typeof createStore>['getState'];

View File

@@ -17,6 +17,7 @@ const Loading = () => {
right={0}
bottom={0}
left={0}
zIndex={99999}
>
<Image src={InvokeLogoWhite} w="8rem" h="8rem" />
<Spinner

View File

@@ -11,13 +11,14 @@ import { memo, useEffect, useMemo, useState } from 'react';
type Props = PropsWithChildren & {
maxHeight?: ChakraProps['maxHeight'];
maxWidth?: ChakraProps['maxWidth'];
overflowX?: 'hidden' | 'scroll';
overflowY?: 'hidden' | 'scroll';
};
const styles: CSSProperties = { position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 };
const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
const ScrollableContent = ({ children, maxHeight, maxWidth, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
const overlayscrollbarsOptions = useMemo(
() => getOverlayScrollbarsParams({ overflowX, overflowY }).options,
[overflowX, overflowY]
@@ -44,7 +45,7 @@ const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflow
}, [os]);
return (
<Flex w="full" h="full" maxHeight={maxHeight} position="relative">
<Flex w="full" h="full" maxHeight={maxHeight} maxWidth={maxWidth} position="relative">
<Box position="absolute" top={0} left={0} right={0} bottom={0}>
<OverlayScrollbarsComponent ref={osRef} style={styles} options={overlayscrollbarsOptions}>
{children}

View File

@@ -1,6 +1,6 @@
.os-scrollbar {
/* The size of the scrollbar */
--os-size: 9px;
--os-size: 8px;
/* The axis-perpedicular padding of the scrollbar (horizontal: padding-y, vertical: padding-x) */
/* --os-padding-perpendicular: 0; */
/* The axis padding of the scrollbar (horizontal: padding-x, vertical: padding-y) */
@@ -8,11 +8,11 @@
/* The border radius of the scrollbar track */
/* --os-track-border-radius: 0; */
/* The background of the scrollbar track */
/* --os-track-bg: rgba(0, 0, 0, 0.3); */
--os-track-bg: rgba(0, 0, 0, 0.5);
/* The :hover background of the scrollbar track */
/* --os-track-bg-hover: rgba(0, 0, 0, 0.3); */
--os-track-bg-hover: rgba(0, 0, 0, 0.5);
/* The :active background of the scrollbar track */
/* --os-track-bg-active: rgba(0, 0, 0, 0.3); */
--os-track-bg-active: rgba(0, 0, 0, 0.6);
/* The border of the scrollbar track */
/* --os-track-border: none; */
/* The :hover background of the scrollbar track */
@@ -22,11 +22,11 @@
/* The border radius of the scrollbar handle */
/* --os-handle-border-radius: 2px; */
/* The background of the scrollbar handle */
/* --os-handle-bg: var(--invokeai-colors-accentAlpha-500); */
--os-handle-bg: var(--invoke-colors-base-400);
/* The :hover background of the scrollbar handle */
/* --os-handle-bg-hover: var(--invokeai-colors-accentAlpha-700); */
--os-handle-bg-hover: var(--invoke-colors-base-300);
/* The :active background of the scrollbar handle */
/* --os-handle-bg-active: var(--invokeai-colors-accentAlpha-800); */
--os-handle-bg-active: var(--invoke-colors-base-250);
/* The border of the scrollbar handle */
/* --os-handle-border: none; */
/* The :hover border of the scrollbar handle */
@@ -34,23 +34,23 @@
/* The :active border of the scrollbar handle */
/* --os-handle-border-active: none; */
/* The min size of the scrollbar handle */
--os-handle-min-size: 50px;
--os-handle-min-size: 32px;
/* The max size of the scrollbar handle */
/* --os-handle-max-size: none; */
/* The axis-perpedicular size of the scrollbar handle (horizontal: height, vertical: width) */
/* --os-handle-perpendicular-size: 100%; */
/* The :hover axis-perpedicular size of the scrollbar handle (horizontal: height, vertical: width) */
/* --os-handle-perpendicular-size-hover: 100%; */
--os-handle-perpendicular-size-hover: 100%;
/* The :active axis-perpedicular size of the scrollbar handle (horizontal: height, vertical: width) */
/* --os-handle-perpendicular-size-active: 100%; */
/* Increases the interactive area of the scrollbar handle. */
/* --os-handle-interactive-area-offset: 0; */
--os-handle-interactive-area-offset: -1px;
}
.os-scrollbar-handle {
cursor: grab;
/* cursor: grab; */
}
.os-scrollbar-handle:active {
cursor: grabbing;
/* cursor: grabbing; */
}

View File

@@ -73,7 +73,7 @@ export const useBoolean = (initialValue: boolean): UseBoolean => {
};
};
type UseDisclosure = {
export type UseDisclosure = {
isOpen: boolean;
open: () => void;
close: () => void;

View File

@@ -0,0 +1,29 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { dropTargetForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { dropTargetForExternal } from '@atlaskit/pragmatic-drag-and-drop/external/adapter';
import { useTimeoutCallback } from 'common/hooks/useTimeoutCallback';
import type { RefObject } from 'react';
import { useEffect } from 'react';
export const useCallbackOnDragEnter = (cb: () => void, ref: RefObject<HTMLElement>, delay = 300) => {
const [run, cancel] = useTimeoutCallback(cb, delay);
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
return combine(
dropTargetForElements({
element,
onDragEnter: run,
onDragLeave: cancel,
}),
dropTargetForExternal({
element,
onDragEnter: run,
onDragLeave: cancel,
})
);
}, [cancel, ref, run]);
};

View File

@@ -0,0 +1,19 @@
import { useEffect } from 'react';
// Chakra tooltips sometimes open during a drag operation. We can fix it by dispatching an event that chakra listens
// for to close tooltips. It's reaching into the internals but it seems to work.
const closeEventName = 'chakra-ui:close-tooltip';
export const useCloseChakraTooltipsOnDragFix = () => {
useEffect(() => {
const closeTooltips = () => {
document.dispatchEvent(new window.CustomEvent(closeEventName));
};
document.addEventListener('drag', closeTooltips);
return () => {
document.removeEventListener('drag', closeTooltips);
};
}, []);
};

View File

@@ -0,0 +1,165 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/**
* Adapted from https://github.com/chakra-ui/chakra-ui/blob/v2/packages/hooks/src/use-outside-click.ts
*
* The main change here is to support filtering of outside clicks via a `filter` function.
*
* This lets us work around issues with portals and components like popovers, which typically close on an outside click.
*
* For example, consider a popover that has a custom drop-down component inside it, which uses a portal to render
* the drop-down options. The original outside click handler would close the popover when clicking on the drop-down options,
* because the click is outside the popover - but we expect the popover to stay open in this case.
*
* A filter function like this can fix that:
*
* ```ts
* const filter = (el: HTMLElement) => el.className.includes('chakra-portal') || el.id.includes('react-select')
* ```
*
* This ignores clicks on react-select-based drop-downs and Chakra UI portals and is used as the default filter.
*/
import { useCallback, useEffect, useRef } from 'react';
type FilterFunction = (el: HTMLElement | SVGElement) => boolean;
export function useCallbackRef<T extends (...args: any[]) => any>(
callback: T | undefined,
deps: React.DependencyList = []
) {
const callbackRef = useRef(callback);
useEffect(() => {
callbackRef.current = callback;
});
// eslint-disable-next-line react-hooks/exhaustive-deps
return useCallback(((...args) => callbackRef.current?.(...args)) as T, deps);
}
export interface UseOutsideClickProps {
/**
* Whether the hook is enabled
*/
enabled?: boolean;
/**
* The reference to a DOM element.
*/
ref: React.RefObject<HTMLElement | null>;
/**
* Function invoked when a click is triggered outside the referenced element.
*/
handler?: (e: Event) => void;
/**
* A function that filters the elements that should be considered as outside clicks.
*
* If omitted, a default filter function that ignores clicks in Chakra UI portals and react-select components is used.
*/
filter?: FilterFunction;
}
export const DEFAULT_FILTER: FilterFunction = (el) => {
if (el instanceof SVGElement) {
// SVGElement's type appears to be incorrect. Its className is not a string, which causes `includes` to fail.
// Let's assume that SVG elements with a class name are not part of the portal and should not be filtered.
return false;
}
return el.className.includes('chakra-portal') || el.id.includes('react-select');
};
/**
* Example, used in components like Dialogs and Popovers, so they can close
* when a user clicks outside them.
*/
export function useFilterableOutsideClick(props: UseOutsideClickProps) {
const { ref, handler, enabled = true, filter = DEFAULT_FILTER } = props;
const savedHandler = useCallbackRef(handler);
const stateRef = useRef({
isPointerDown: false,
ignoreEmulatedMouseEvents: false,
});
const state = stateRef.current;
useEffect(() => {
if (!enabled) {
return;
}
const onPointerDown: any = (e: PointerEvent) => {
if (isValidEvent(e, ref, filter)) {
state.isPointerDown = true;
}
};
const onMouseUp: any = (event: MouseEvent) => {
if (state.ignoreEmulatedMouseEvents) {
state.ignoreEmulatedMouseEvents = false;
return;
}
if (state.isPointerDown && handler && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const onTouchEnd = (event: TouchEvent) => {
state.ignoreEmulatedMouseEvents = true;
if (handler && state.isPointerDown && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const doc = getOwnerDocument(ref.current);
doc.addEventListener('mousedown', onPointerDown, true);
doc.addEventListener('mouseup', onMouseUp, true);
doc.addEventListener('touchstart', onPointerDown, true);
doc.addEventListener('touchend', onTouchEnd, true);
return () => {
doc.removeEventListener('mousedown', onPointerDown, true);
doc.removeEventListener('mouseup', onMouseUp, true);
doc.removeEventListener('touchstart', onPointerDown, true);
doc.removeEventListener('touchend', onTouchEnd, true);
};
}, [handler, ref, savedHandler, state, enabled, filter]);
}
function isValidEvent(event: Event, ref: React.RefObject<HTMLElement | null>, filter?: FilterFunction): boolean {
const target = (event.composedPath?.()[0] ?? event.target) as HTMLElement;
if (target) {
const doc = getOwnerDocument(target);
if (!doc.contains(target)) {
return false;
}
}
if (ref.current?.contains(target)) {
return false;
}
// This is the main logic change from the original hook.
if (filter) {
// Check if the click is inside an element matching the filter.
// This is used for portal-awareness or other general exclusion cases.
let currentElement: HTMLElement | null = target;
// Traverse up the DOM tree from the target element.
while (currentElement && currentElement !== document.body) {
if (filter(currentElement)) {
return false;
}
currentElement = currentElement.parentElement;
}
}
// If the click is not inside the ref and not inside a portal, it's a valid outside click.
return true;
}
function getOwnerDocument(node?: Element | null): Document {
return node?.ownerDocument ?? document;
}

View File

@@ -1,6 +1,6 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
import { useDeleteCurrentQueueItem } from 'features/queue/hooks/useDeleteCurrentQueueItem';
import { useInvoke } from 'features/queue/hooks/useInvoke';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@@ -35,34 +35,39 @@ export const useGlobalHotkeys = () => {
dependencies: [queue],
});
const {
cancelQueueItem,
isDisabled: isDisabledCancelQueueItem,
isLoading: isLoadingCancelQueueItem,
} = useCancelCurrentQueueItem();
const deleteCurrentQueueItem = useDeleteCurrentQueueItem();
useRegisteredHotkeys({
id: 'cancelQueueItem',
category: 'app',
callback: cancelQueueItem,
callback: deleteCurrentQueueItem.trigger,
options: {
enabled: !isDisabledCancelQueueItem && !isLoadingCancelQueueItem,
enabled: !deleteCurrentQueueItem.isDisabled && !deleteCurrentQueueItem.isLoading,
preventDefault: true,
},
dependencies: [cancelQueueItem, isDisabledCancelQueueItem, isLoadingCancelQueueItem],
dependencies: [deleteCurrentQueueItem],
});
const { clearQueue, isDisabled: isDisabledClearQueue, isLoading: isLoadingClearQueue } = useClearQueue();
const clearQueue = useClearQueue();
useRegisteredHotkeys({
id: 'clearQueue',
category: 'app',
callback: clearQueue,
callback: clearQueue.trigger,
options: {
enabled: !isDisabledClearQueue && !isLoadingClearQueue,
enabled: !clearQueue.isDisabled && !clearQueue.isLoading,
preventDefault: true,
},
dependencies: [clearQueue, isDisabledClearQueue, isLoadingClearQueue],
dependencies: [clearQueue],
});
useRegisteredHotkeys({
id: 'selectGenerateTab',
category: 'app',
callback: () => {
dispatch(setActiveTab('generate'));
},
dependencies: [dispatch],
});
useRegisteredHotkeys({

View File

@@ -1,11 +1,11 @@
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import type { ButtonProps, IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, IconButton } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
@@ -163,32 +163,63 @@ const sx = {
},
} satisfies SystemStyleObject;
export const UploadImageButton = ({
isDisabled = false,
onUpload,
isError = false,
...rest
}: {
export const UploadImageIconButton = memo(
({
isDisabled = false,
onUpload,
isError = false,
...rest
}: {
onUpload?: (imageDTO: ImageDTO) => void;
isError?: boolean;
} & SetOptional<IconButtonProps, 'aria-label'>) => {
const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: false, onUpload });
return (
<>
<IconButton
aria-label="Upload image"
variant="outline"
sx={sx}
data-error={isError}
icon={<PiUploadBold />}
isLoading={uploadApi.request.isLoading}
{...rest}
{...uploadApi.getUploadButtonProps()}
/>
<input {...uploadApi.getUploadInputProps()} />
</>
);
}
);
UploadImageIconButton.displayName = 'UploadImageIconButton';
type UploadImageButtonProps = {
onUpload?: (imageDTO: ImageDTO) => void;
isError?: boolean;
} & SetOptional<IconButtonProps, 'aria-label'>) => {
} & ButtonProps;
const UploadImageButton = memo((props: UploadImageButtonProps) => {
const { children, isDisabled = false, onUpload, isError = false, ...rest } = props;
const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: false, onUpload });
return (
<>
<IconButton
<Button
aria-label="Upload image"
variant="outline"
sx={sx}
data-error={isError}
icon={<PiUploadBold />}
rightIcon={<PiUploadBold />}
isLoading={uploadApi.request.isLoading}
{...rest}
{...uploadApi.getUploadButtonProps()}
/>
>
{children ?? 'Upload'}
</Button>
<input {...uploadApi.getUploadInputProps()} />
</>
);
};
});
UploadImageButton.displayName = 'UploadImageButton';
export const UploadMultipleImageButton = ({
isDisabled = false,

View File

@@ -1,12 +1,18 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { uniq } from 'lodash-es';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
import type { AnyModelConfig } from 'services/api/types';
import { useGroupedModelCombobox } from './useGroupedModelCombobox';
import { useRelatedModelKeys } from './useRelatedModelKeys';
import { useSelectedModelKeys } from './useSelectedModelKeys';
type UseRelatedGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelConfigs: T[];
@@ -29,6 +35,32 @@ type UseRelatedGroupedModelComboboxReturn = {
noOptionsMessage: () => string;
};
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
const keys: string[] = [];
const main = params.model;
const vae = params.vae;
const refiner = params.refinerModel;
const controlnet = params.controlLora;
if (main) {
keys.push(main.key);
}
if (vae) {
keys.push(vae.key);
}
if (refiner) {
keys.push(refiner.key);
}
if (controlnet) {
keys.push(controlnet.key);
}
for (const { model } of loras.loras) {
keys.push(model.key);
}
return uniq(keys);
});
export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
modelConfigs,
selectedModel,
@@ -39,9 +71,15 @@ export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
}: UseRelatedGroupedModelComboboxArg<T>): UseRelatedGroupedModelComboboxReturn {
const { t } = useTranslation();
const selectedKeys = useSelectedModelKeys();
const relatedKeys = useRelatedModelKeys(selectedKeys);
const selectedKeys = useAppSelector(selectSelectedModelKeys);
const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, {
selectFromResult: ({ data }) => {
if (!data) {
return { relatedKeys: EMPTY_ARRAY };
}
return { relatedKeys: data };
},
});
// Base grouped options
const base = useGroupedModelCombobox({
@@ -53,40 +91,42 @@ export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
groupByType,
});
// If no related models selected, just return base
if (relatedKeys.size === 0) {
return base;
}
const options = useMemo(() => {
if (relatedKeys.length === 0) {
return base.options;
}
const relatedOptions: ComboboxOption[] = [];
const updatedGroups: GroupBase<ComboboxOption>[] = [];
const relatedOptions: ComboboxOption[] = [];
const updatedGroups: GroupBase<ComboboxOption>[] = [];
for (const group of base.options) {
const remainingOptions: ComboboxOption[] = [];
for (const group of base.options) {
const remainingOptions: ComboboxOption[] = [];
for (const option of group.options) {
if (relatedKeys.has(option.value)) {
relatedOptions.push({ ...option, label: `* ${option.label}` });
} else {
remainingOptions.push(option);
for (const option of group.options) {
if (relatedKeys.includes(option.value)) {
relatedOptions.push({ ...option, label: `* ${option.label}` });
} else {
remainingOptions.push(option);
}
}
if (remainingOptions.length > 0) {
updatedGroups.push({
label: group.label,
options: remainingOptions,
});
}
}
if (remainingOptions.length > 0) {
updatedGroups.push({
label: group.label,
options: remainingOptions,
});
if (relatedOptions.length > 0) {
return [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups];
} else {
return updatedGroups;
}
}
const finalOptions: GroupBase<ComboboxOption>[] =
relatedOptions.length > 0
? [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups]
: updatedGroups;
}, [base.options, relatedKeys, t]);
return {
...base,
options: finalOptions,
options,
};
}

View File

@@ -1,14 +0,0 @@
import { useMemo } from 'react';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
/**
* Fetches related model keys for a given set of selected model keys.
* Returns a Set<string> for fast lookup.
*/
export const useRelatedModelKeys = (selectedKeys: Set<string>) => {
const { data: related = [] } = useGetRelatedModelIdsBatchQuery([...selectedKeys], {
skip: selectedKeys.size === 0,
});
return useMemo(() => new Set(related), [related]);
};

View File

@@ -1,34 +0,0 @@
import { useAppSelector } from 'app/store/storeHooks';
/**
* Gathers all currently selected model keys from parameters and loras.
* This includes the main model, VAE, refiner model, controlnet, and loras.
*/
export const useSelectedModelKeys = () => {
return useAppSelector((state) => {
const keys = new Set<string>();
const main = state.params.model;
const vae = state.params.vae;
const refiner = state.params.refinerModel;
const controlnet = state.params.controlLora;
const loras = state.loras.loras.map((l) => l.model);
if (main) {
keys.add(main.key);
}
if (vae) {
keys.add(vae.key);
}
if (refiner) {
keys.add(refiner.key);
}
if (controlnet) {
keys.add(controlnet.key);
}
for (const lora of loras) {
keys.add(lora.key);
}
return keys;
});
};

View File

@@ -0,0 +1,28 @@
import type { Selector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import type { Atom, WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import { useEffect, useState } from 'react';
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const useSelectorAsAtom = <T extends Selector<RootState, any>>(selector: T): Atom<ReturnType<T>> => {
const store = useAppStore();
const $atom = useState<WritableAtom<ReturnType<T>>>(() => atom<ReturnType<T>>(selector(store.getState())))[0];
useEffect(() => {
const unsubscribe = store.subscribe(() => {
const prev = $atom.get();
const next = selector(store.getState());
if (prev !== next) {
$atom.set(next);
}
});
return () => {
unsubscribe();
};
}, [$atom, selector, store]);
return $atom;
};

View File

@@ -0,0 +1,21 @@
import { useCallback, useMemo, useRef } from 'react';
export const useTimeoutCallback = (callback: () => void, delay: number, onCancel?: () => void) => {
const timeoutRef = useRef<number | null>(null);
const cancel = useCallback(() => {
if (timeoutRef.current !== null) {
window.clearTimeout(timeoutRef.current);
timeoutRef.current = null;
onCancel?.();
}
}, [onCancel]);
const callWithTimeout = useCallback(() => {
cancel();
timeoutRef.current = window.setTimeout(() => {
callback();
timeoutRef.current = null;
}, delay);
}, [callback, cancel, delay]);
const api = useMemo(() => [callWithTimeout, cancel] as const, [callWithTimeout, cancel]);
return api;
};

View File

@@ -0,0 +1,10 @@
import type { z } from 'zod';
/**
* Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.
* @param schema The zod schema to create a type guard from.
* @returns A type guard function for the schema.
*/
export const buildZodTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
};

View File

@@ -16,7 +16,7 @@ import { useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation } from 's
const selectImagesToChange = createMemoizedSelector(
selectChangeBoardModalSlice,
(changeBoardModal) => changeBoardModal.imagesToChange
(changeBoardModal) => changeBoardModal.image_names
);
const selectIsModalOpen = createSelector(
@@ -57,10 +57,10 @@ const ChangeBoardModal = () => {
}
if (selectedBoard === 'none') {
removeImagesFromBoard({ imageDTOs: imagesToChange });
removeImagesFromBoard({ image_names: imagesToChange });
} else {
addImagesToBoard({
imageDTOs: imagesToChange,
image_names: imagesToChange,
board_id: selectedBoard,
});
}

View File

@@ -2,5 +2,5 @@ import type { ChangeBoardModalState } from './types';
export const initialState: ChangeBoardModalState = {
isModalOpen: false,
imagesToChange: [],
image_names: [],
};

View File

@@ -1,7 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import type { ImageDTO } from 'services/api/types';
import { initialState } from './initialState';
@@ -12,11 +11,11 @@ export const changeBoardModalSlice = createSlice({
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isModalOpen = action.payload;
},
imagesToChangeSelected: (state, action: PayloadAction<ImageDTO[]>) => {
state.imagesToChange = action.payload;
imagesToChangeSelected: (state, action: PayloadAction<string[]>) => {
state.image_names = action.payload;
},
changeBoardReset: (state) => {
state.imagesToChange = [];
state.image_names = [];
state.isModalOpen = false;
},
},

View File

@@ -1,6 +1,4 @@
import type { ImageDTO } from 'services/api/types';
export type ChangeBoardModalState = {
isModalOpen: boolean;
imagesToChange: ImageDTO[];
image_names: string[];
};

View File

@@ -0,0 +1,182 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
ContextMenu,
Divider,
Flex,
IconButton,
Menu,
MenuButton,
MenuList,
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
} from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
import { CanvasAlertsInvocationProgress } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsInvocationProgress';
import { CanvasAlertsPreserveMask } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsPreserveMask';
import { CanvasAlertsSelectedEntityStatus } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSelectedEntityStatus';
import { CanvasContextMenuGlobalMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems';
import { CanvasContextMenuSelectedEntityMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuSelectedEntityMenuItems';
import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea';
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
import { SelectObject } from 'features/controlLayers/components/SelectObject/SelectObject';
import { CanvasSessionContextProvider } from 'features/controlLayers/components/SimpleSession/context';
import { GenerateLaunchpadPanel } from 'features/controlLayers/components/SimpleSession/GenerateLaunchpadPanel';
import { StagingAreaItemsList } from 'features/controlLayers/components/SimpleSession/StagingAreaItemsList';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
import { Transform } from 'features/controlLayers/components/Transform/Transform';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar';
import { memo, useCallback } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
const FOCUS_REGION_STYLES: SystemStyleObject = {
width: 'full',
height: 'full',
};
const MenuContent = memo(() => {
return (
<CanvasManagerProviderGate>
<MenuList>
<CanvasContextMenuSelectedEntityMenuItems />
<CanvasContextMenuGlobalMenuItems />
</MenuList>
</CanvasManagerProviderGate>
);
});
MenuContent.displayName = 'MenuContent';
const canvasBgSx = {
position: 'relative',
w: 'full',
h: 'full',
borderRadius: 'base',
overflow: 'hidden',
bg: 'base.900',
'&[data-dynamic-grid="true"]': {
bg: 'base.850',
},
};
export const AdvancedSession = memo(({ id }: { id: string | null }) => {
const dynamicGrid = useAppSelector(selectDynamicGrid);
const showHUD = useAppSelector(selectShowHUD);
const renderMenu = useCallback(() => {
return <MenuContent />;
}, []);
return (
<Tabs w="full" h="full">
<TabList>
<Tab>Welcome</Tab>
<Tab>Workspace</Tab>
<Tab>Viewer</Tab>
</TabList>
<TabPanels w="full" h="full">
<TabPanel w="full" h="full" justifyContent="center">
<GenerateLaunchpadPanel />
</TabPanel>
<TabPanel w="full" h="full">
<FocusRegionWrapper region="canvas" sx={FOCUS_REGION_STYLES}>
<Flex
tabIndex={-1}
borderRadius="base"
position="relative"
flexDirection="column"
height="full"
width="full"
gap={2}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
<CanvasManagerProviderGate>
<CanvasToolbar />
</CanvasManagerProviderGate>
<Divider />
<ContextMenu<HTMLDivElement> renderMenu={renderMenu} withLongPress={false}>
{(ref) => (
<Flex ref={ref} sx={canvasBgSx} data-dynamic-grid={dynamicGrid}>
<InvokeCanvasComponent />
<CanvasManagerProviderGate>
<Flex
position="absolute"
flexDir="column"
top={1}
insetInlineStart={1}
pointerEvents="none"
gap={2}
alignItems="flex-start"
>
{showHUD && <CanvasHUD />}
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsInvocationProgress />
</Flex>
<Flex position="absolute" top={1} insetInlineEnd={1}>
<Menu>
<MenuButton as={IconButton} icon={<PiDotsThreeOutlineVerticalFill />} colorScheme="base" />
<MenuContent />
</Menu>
</Flex>
</CanvasManagerProviderGate>
</Flex>
)}
</ContextMenu>
{id !== null && (
<CanvasManagerProviderGate>
<CanvasSessionContextProvider type="advanced" id={id}>
<Flex
position="absolute"
flexDir="column"
bottom={4}
gap={2}
align="center"
justify="center"
left={4}
right={4}
>
<Flex position="relative" maxW="full" w="full" h={108}>
<StagingAreaItemsList />
</Flex>
<Flex gap={2}>
<StagingAreaToolbar />
</Flex>
</Flex>
</CanvasSessionContextProvider>
</CanvasManagerProviderGate>
)}
<Flex position="absolute" bottom={4}>
<CanvasManagerProviderGate>
<Filter />
<Transform />
<SelectObject />
</CanvasManagerProviderGate>
</Flex>
<CanvasManagerProviderGate>
<CanvasDropArea />
</CanvasManagerProviderGate>
</Flex>
</FocusRegionWrapper>
</TabPanel>
<TabPanel w="full" h="full">
<Flex flexDir="column" w="full" h="full">
<ViewerToolbar />
<ImageViewer />
</Flex>
</TabPanel>
</TabPanels>
</Tabs>
);
});
AdvancedSession.displayName = 'AdvancedSession';

View File

@@ -2,11 +2,10 @@ import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddNewRegionalGuidanceWithARefImage,
useAddRasterLayer,
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useIsEntityTypeEnabled } from 'features/controlLayers/hooks/useIsEntityTypeEnabled';
import { memo } from 'react';
@@ -19,9 +18,7 @@ export const CanvasAddEntityButtons = memo(() => {
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const addRegionalReferenceImage = useAddNewRegionalGuidanceWithARefImage();
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
@@ -29,21 +26,6 @@ export const CanvasAddEntityButtons = memo(() => {
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
<Flex position="relative" flexDir="column" gap={4} top="20%">
<Flex flexDir="column" justifyContent="flex-start" gap={2}>
<Heading size="xs">{t('controlLayers.global')}</Heading>
<InformationalPopover feature="globalReferenceImage">
<Button
size="sm"
variant="ghost"
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
isDisabled={!isReferenceImageEnabled}
>
{t('controlLayers.globalReferenceImage')}
</Button>
</InformationalPopover>
</Flex>
<Flex flexDir="column" gap={2}>
<Heading size="xs">{t('controlLayers.regional')}</Heading>
<InformationalPopover feature="inpainting">

View File

@@ -6,11 +6,11 @@ import { selectIsLocal } from 'features/system/store/configSlice';
import { selectSystemShouldShowInvocationProgressDetail } from 'features/system/store/systemSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { $invocationProgressMessage } from 'services/events/stores';
import { $lastProgressMessage } from 'services/events/stores';
const CanvasAlertsInvocationProgressContentLocal = memo(() => {
const { t } = useTranslation();
const invocationProgressMessage = useStore($invocationProgressMessage);
const invocationProgressMessage = useStore($lastProgressMessage);
if (!invocationProgressMessage) {
return null;

View File

@@ -1,146 +0,0 @@
import { Alert, AlertDescription, AlertIcon, AlertTitle, Button, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useBoolean } from 'common/hooks/useBoolean';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useCurrentDestination } from 'features/queue/hooks/useCurrentDestination';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { activeTabCanvasRightPanelChanged, setActiveTab } from 'features/ui/store/uiSlice';
import { AnimatePresence, motion } from 'framer-motion';
import type { PropsWithChildren, ReactNode } from 'react';
import { useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
const ActivateImageViewerButton = (props: PropsWithChildren) => {
const imageViewer = useImageViewer();
const dispatch = useAppDispatch();
const onClick = useCallback(() => {
imageViewer.open();
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [imageViewer, dispatch]);
return (
<Button onClick={onClick} size="sm" variant="link" color="base.50">
{props.children}
</Button>
);
};
export const CanvasAlertsSendingToGallery = () => {
const { t } = useTranslation();
const destination = useCurrentDestination();
const tab = useAppSelector(selectActiveTab);
const isVisible = useMemo(() => {
// This alert should only be visible when the destination is gallery and the tab is canvas
if (tab !== 'canvas') {
return false;
}
if (!destination) {
return false;
}
return destination === 'gallery';
}, [destination, tab]);
return (
<AlertWrapper
title={t('controlLayers.sendingToGallery')}
description={
<Trans i18nKey="controlLayers.viewProgressInViewer" components={{ Btn: <ActivateImageViewerButton /> }} />
}
isVisible={isVisible}
/>
);
};
const ActivateCanvasButton = (props: PropsWithChildren) => {
const dispatch = useAppDispatch();
const imageViewer = useImageViewer();
const onClick = useCallback(() => {
dispatch(setActiveTab('canvas'));
dispatch(activeTabCanvasRightPanelChanged('layers'));
imageViewer.close();
}, [dispatch, imageViewer]);
return (
<Button onClick={onClick} size="sm" variant="link" color="base.50">
{props.children}
</Button>
);
};
export const CanvasAlertsSendingToCanvas = () => {
const { t } = useTranslation();
const destination = useCurrentDestination();
const isStaging = useAppSelector(selectIsStaging);
const tab = useAppSelector(selectActiveTab);
const isVisible = useMemo(() => {
// When we are on a non-canvas tab, and the current generation's destination is not the canvas, we don't show the alert
// For example, on the workflows tab, when the destinatin is gallery, we don't show the alert
if (tab !== 'canvas' && destination !== 'canvas') {
return false;
}
if (isStaging) {
return true;
}
if (!destination) {
return false;
}
return destination === 'canvas';
}, [destination, isStaging, tab]);
return (
<AlertWrapper
title={t('controlLayers.sendingToCanvas')}
description={
<Trans i18nKey="controlLayers.viewProgressOnCanvas" components={{ Btn: <ActivateCanvasButton /> }} />
}
isVisible={isVisible}
/>
);
};
const AlertWrapper = ({
title,
description,
isVisible,
}: {
title: ReactNode;
description: ReactNode;
isVisible: boolean;
}) => {
const isHovered = useBoolean(false);
return (
<AnimatePresence>
{(isVisible || isHovered.isTrue) && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1, transition: { duration: 0.1, ease: 'easeOut' } }}
exit={{
opacity: 0,
transition: { duration: 0.1, delay: !isHovered.isTrue ? 1 : 0.1, ease: 'easeIn' },
}}
onMouseEnter={isHovered.setTrue}
onMouseLeave={isHovered.setFalse}
>
<Alert
status="warning"
flexDir="column"
pointerEvents="auto"
borderRadius="base"
fontSize="sm"
shadow="md"
w="fit-content"
>
<Flex w="full" alignItems="center">
<AlertIcon />
<AlertTitle>{title}</AlertTitle>
</Flex>
<AlertDescription>{description}</AlertDescription>
</Alert>
</motion.div>
)}
</AnimatePresence>
);
};

View File

@@ -2,8 +2,8 @@ import { MenuGroup } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ControlLayerMenuItems } from 'features/controlLayers/components/ControlLayer/ControlLayerMenuItems';
import { InpaintMaskMenuItems } from 'features/controlLayers/components/InpaintMask/InpaintMaskMenuItems';
import { IPAdapterMenuItems } from 'features/controlLayers/components/IPAdapter/IPAdapterMenuItems';
import { RasterLayerMenuItems } from 'features/controlLayers/components/RasterLayer/RasterLayerMenuItems';
import { IPAdapterMenuItems } from 'features/controlLayers/components/RefImage/IPAdapterMenuItems';
import { RegionalGuidanceMenuItems } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceMenuItems';
import { CanvasEntityStateGate } from 'features/controlLayers/contexts/CanvasEntityStateGate';
import {

View File

@@ -2,7 +2,6 @@ import { Grid, GridItem } from '@invoke-ai/ui-library';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { newCanvasEntityFromImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -13,19 +12,11 @@ const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'regional_guidance_with_reference_image',
});
const addGlobalReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'reference_image',
});
export const CanvasDropArea = memo(() => {
const { t } = useTranslation();
const imageViewer = useImageViewer();
const isBusy = useCanvasIsBusy();
if (imageViewer.isOpen) {
return null;
}
return (
<>
<Grid
@@ -63,14 +54,6 @@ export const CanvasDropArea = memo(() => {
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addGlobalReferenceImageFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newGlobalReferenceImage')}
isDisabled={isBusy}
/>
</GridItem>
</Grid>
</>
);

View File

@@ -14,7 +14,6 @@ import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
import { entitiesReordered } from 'features/controlLayers/store/canvasSlice';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { isRenderableEntityType } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { triggerPostMoveFlash } from 'features/dnd/util';
import type { PropsWithChildren } from 'react';
@@ -165,8 +164,8 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children, entityI
<Spacer />
</Flex>
{isRenderableEntityType(type) && <CanvasEntityMergeVisibleButton type={type} />}
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
<CanvasEntityMergeVisibleButton type={type} />
<CanvasEntityTypeIsHiddenToggle type={type} />
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>

View File

@@ -2,7 +2,6 @@ import { Flex } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { ControlLayerEntityList } from 'features/controlLayers/components/ControlLayer/ControlLayerEntityList';
import { InpaintMaskList } from 'features/controlLayers/components/InpaintMask/InpaintMaskList';
import { IPAdapterList } from 'features/controlLayers/components/IPAdapter/IPAdapterList';
import { RasterLayerEntityList } from 'features/controlLayers/components/RasterLayer/RasterLayerEntityList';
import { RegionalGuidanceEntityList } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceEntityList';
import { memo } from 'react';
@@ -11,7 +10,6 @@ export const CanvasEntityList = memo(() => {
return (
<ScrollableContent>
<Flex flexDir="column" gap={2} data-testid="control-layers-layer-list" w="full" h="full">
<IPAdapterList />
<InpaintMaskList />
<RegionalGuidanceEntityList />
<ControlLayerEntityList />

View File

@@ -1,11 +1,10 @@
import { IconButton, Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import {
useAddControlLayer,
useAddGlobalReferenceImage,
useAddInpaintMask,
useAddNewRegionalGuidanceWithARefImage,
useAddRasterLayer,
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useIsEntityTypeEnabled } from 'features/controlLayers/hooks/useIsEntityTypeEnabled';
@@ -16,13 +15,11 @@ import { PiPlusBold } from 'react-icons/pi';
export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addInpaintMask = useAddInpaintMask();
const addRegionalGuidance = useAddRegionalGuidance();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const addRegionalReferenceImage = useAddNewRegionalGuidanceWithARefImage();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isReferenceImageEnabled = useIsEntityTypeEnabled('reference_image');
const isRegionalGuidanceEnabled = useIsEntityTypeEnabled('regional_guidance');
const isControlLayerEnabled = useIsEntityTypeEnabled('control_layer');
const isInpaintLayerEnabled = useIsEntityTypeEnabled('inpaint_mask');
@@ -41,11 +38,6 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
isDisabled={isBusy}
/>
<MenuList>
<MenuGroup title={t('controlLayers.global')}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={!isReferenceImageEnabled}>
{t('controlLayers.globalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.regional')}>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask} isDisabled={!isInpaintLayerEnabled}>
{t('controlLayers.inpaintMask')}

View File

@@ -22,7 +22,6 @@ import {
selectEntity,
selectSelectedEntityIdentifier,
} from 'features/controlLayers/store/selectors';
import { isRenderableEntity } from 'features/controlLayers/store/types';
import { clamp, round } from 'lodash-es';
import type { KeyboardEvent } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
@@ -70,9 +69,6 @@ const selectOpacity = createSelector(selectCanvasSlice, (canvas) => {
if (!selectedEntity) {
return 1; // fallback to 100% opacity
}
if (!isRenderableEntity(selectedEntity)) {
return 1; // fallback to 100% opacity
}
// Opacity is a float from 0-1, but we want to display it as a percentage
return selectedEntity.opacity;
});
@@ -134,11 +130,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
return (
<Popover>
<FormControl
w="min-content"
gap={2}
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
>
<FormControl w="min-content" gap={2} isDisabled={selectedEntityIdentifier === null}>
<FormLabel m={0}>{t('controlLayers.opacity')}</FormLabel>
<PopoverAnchor>
<NumberInput
@@ -167,7 +159,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
position="absolute"
insetInlineEnd={0}
h="full"
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null}
/>
</PopoverTrigger>
</NumberInput>
@@ -185,7 +177,7 @@ export const EntityListSelectedEntityActionBarOpacity = memo(() => {
marks={marks}
formatValue={formatSliderValue}
alwaysShowMarks
isDisabled={selectedEntityIdentifier === null || selectedEntityIdentifier.type === 'reference_image'}
isDisabled={selectedEntityIdentifier === null}
/>
</PopoverBody>
</PopoverContent>

View File

@@ -1,25 +1,20 @@
import { Divider, Flex, type SystemStyleObject } from '@invoke-ai/ui-library';
import { Divider, Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
import { CanvasAddEntityButtons } from 'features/controlLayers/components/CanvasAddEntityButtons';
import { CanvasEntityList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityList';
import { EntityListSelectedEntityActionBar } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBar';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectHasEntities } from 'features/controlLayers/store/selectors';
import { memo } from 'react';
import { ParamDenoisingStrength } from './ParamDenoisingStrength';
const FOCUS_REGION_STYLES: SystemStyleObject = {
width: 'full',
height: 'full',
};
export const CanvasLayersPanelContent = memo(() => {
export const CanvasLayersPanel = memo(() => {
const hasEntities = useAppSelector(selectHasEntities);
return (
<FocusRegionWrapper region="layers" sx={FOCUS_REGION_STYLES}>
<Flex flexDir="column" gap={2} w="full" h="full">
<CanvasManagerProviderGate>
<Flex flexDir="column" gap={2} w="full" h="full" p={2}>
<EntityListSelectedEntityActionBar />
<Divider py={0} />
<ParamDenoisingStrength />
@@ -27,8 +22,8 @@ export const CanvasLayersPanelContent = memo(() => {
{!hasEntities && <CanvasAddEntityButtons />}
{hasEntities && <CanvasEntityList />}
</Flex>
</FocusRegionWrapper>
</CanvasManagerProviderGate>
);
});
CanvasLayersPanelContent.displayName = 'CanvasLayersPanelContent';
CanvasLayersPanel.displayName = 'CanvasLayersPanel';

View File

@@ -1,140 +0,0 @@
import {
ContextMenu,
Flex,
IconButton,
Menu,
MenuButton,
MenuList,
type SystemStyleObject,
} from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
import { CanvasAlertsPreserveMask } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsPreserveMask';
import { CanvasAlertsSelectedEntityStatus } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSelectedEntityStatus';
import { CanvasAlertsSendingToGallery } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo';
import { CanvasBusySpinner } from 'features/controlLayers/components/CanvasBusySpinner';
import { CanvasContextMenuGlobalMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems';
import { CanvasContextMenuSelectedEntityMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuSelectedEntityMenuItems';
import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea';
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
import { SelectObject } from 'features/controlLayers/components/SelectObject/SelectObject';
import { StagingAreaIsStagingGate } from 'features/controlLayers/components/StagingArea/StagingAreaIsStagingGate';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
import { Transform } from 'features/controlLayers/components/Transform/Transform';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
import { GatedImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import { memo, useCallback } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';
const FOCUS_REGION_STYLES: SystemStyleObject = {
width: 'full',
height: 'full',
};
const MenuContent = () => {
return (
<CanvasManagerProviderGate>
<MenuList>
<CanvasContextMenuSelectedEntityMenuItems />
<CanvasContextMenuGlobalMenuItems />
</MenuList>
</CanvasManagerProviderGate>
);
};
export const CanvasMainPanelContent = memo(() => {
const dynamicGrid = useAppSelector(selectDynamicGrid);
const showHUD = useAppSelector(selectShowHUD);
const renderMenu = useCallback(() => {
return <MenuContent />;
}, []);
return (
<FocusRegionWrapper region="canvas" sx={FOCUS_REGION_STYLES}>
<Flex
tabIndex={-1}
borderRadius="base"
position="relative"
flexDirection="column"
height="full"
width="full"
gap={2}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
<CanvasManagerProviderGate>
<CanvasToolbar />
</CanvasManagerProviderGate>
<ContextMenu<HTMLDivElement> renderMenu={renderMenu} withLongPress={false}>
{(ref) => (
<Flex
ref={ref}
position="relative"
w="full"
h="full"
bg={dynamicGrid ? 'base.850' : 'base.900'}
borderRadius="base"
overflow="hidden"
>
<InvokeCanvasComponent />
<CanvasManagerProviderGate>
<Flex
position="absolute"
flexDir="column"
top={1}
insetInlineStart={1}
pointerEvents="none"
gap={2}
alignItems="flex-start"
>
{showHUD && <CanvasHUD />}
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsSendingToGallery />
<CanvasAlertsInvocationProgress />
</Flex>
<Flex position="absolute" top={1} insetInlineEnd={1}>
<Menu>
<MenuButton as={IconButton} icon={<PiDotsThreeOutlineVerticalFill />} colorScheme="base" />
<MenuContent />
</Menu>
</Flex>
<Flex position="absolute" bottom={4} insetInlineEnd={4}>
<CanvasBusySpinner />
</Flex>
</CanvasManagerProviderGate>
</Flex>
)}
</ContextMenu>
<Flex position="absolute" bottom={4} gap={2} align="center" justify="center">
<CanvasManagerProviderGate>
<StagingAreaIsStagingGate>
<StagingAreaToolbar />
</StagingAreaIsStagingGate>
</CanvasManagerProviderGate>
</Flex>
<Flex position="absolute" bottom={4}>
<CanvasManagerProviderGate>
<Filter />
<Transform />
<SelectObject />
</CanvasManagerProviderGate>
</Flex>
<CanvasManagerProviderGate>
<CanvasDropArea />
</CanvasManagerProviderGate>
<GatedImageViewer />
</Flex>
</FocusRegionWrapper>
);
});
CanvasMainPanelContent.displayName = 'CanvasMainPanelContent';

View File

@@ -1,272 +0,0 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { dropTargetForElements, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import { dropTargetForExternal, monitorForExternal } from '@atlaskit/pragmatic-drag-and-drop/external/adapter';
import { Box, Button, Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { CanvasLayersPanelContent } from 'features/controlLayers/components/CanvasLayersPanelContent';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectEntityCountActive } from 'features/controlLayers/store/selectors';
import { multipleImageDndSource, singleImageDndSource } from 'features/dnd/dnd';
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
import type { DndTargetState } from 'features/dnd/types';
import GalleryPanelContent from 'features/gallery/components/GalleryPanelContent';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { selectActiveTabCanvasRightPanel } from 'features/ui/store/uiSelectors';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasRightPanel = memo(() => {
const { t } = useTranslation();
const activeTab = useAppSelector(selectActiveTabCanvasRightPanel);
const imageViewer = useImageViewer();
const dispatch = useAppDispatch();
const tabIndex = useMemo(() => {
if (activeTab === 'gallery') {
return 1;
} else {
return 0;
}
}, [activeTab]);
const onClickViewerToggleButton = useCallback(() => {
if (activeTab !== 'gallery') {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}
imageViewer.toggle();
}, [imageViewer, activeTab, dispatch]);
const onChangeTab = useCallback(
(index: number) => {
if (index === 0) {
dispatch(activeTabCanvasRightPanelChanged('layers'));
} else {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}
},
[dispatch]
);
useRegisteredHotkeys({
id: 'toggleViewer',
category: 'viewer',
callback: imageViewer.toggle,
dependencies: [imageViewer],
});
return (
<Tabs index={tabIndex} onChange={onChangeTab} w="full" h="full" display="flex" flexDir="column">
<TabList alignItems="center">
<PanelTabs />
<Spacer />
<Button size="sm" variant="ghost" onClick={onClickViewerToggleButton}>
{imageViewer.isOpen ? t('gallery.closeViewer') : t('gallery.openViewer')}
</Button>
</TabList>
<TabPanels w="full" h="full">
<TabPanel w="full" h="full" p={0} pt={3}>
<CanvasManagerProviderGate>
<CanvasLayersPanelContent />
</CanvasManagerProviderGate>
</TabPanel>
<TabPanel w="full" h="full" p={0} pt={3}>
<GalleryPanelContent />
</TabPanel>
</TabPanels>
</Tabs>
);
});
CanvasRightPanel.displayName = 'CanvasRightPanel';
const PanelTabs = memo(() => {
const { t } = useTranslation();
const store = useAppStore();
const activeEntityCount = useAppSelector(selectEntityCountActive);
const [layersTabDndState, setLayersTabDndState] = useState<DndTargetState>('idle');
const [galleryTabDndState, setGalleryTabDndState] = useState<DndTargetState>('idle');
const layersTabRef = useRef<HTMLDivElement>(null);
const galleryTabRef = useRef<HTMLDivElement>(null);
const timeoutRef = useRef<number | null>(null);
const layersTabLabel = useMemo(() => {
if (activeEntityCount === 0) {
return t('controlLayers.layer_other');
}
return `${t('controlLayers.layer_other')} (${activeEntityCount})`;
}, [activeEntityCount, t]);
useEffect(() => {
if (!layersTabRef.current) {
return;
}
const getIsOnLayersTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'layers';
const onDragEnter = () => {
// If we are already on the layers tab, do nothing
if (getIsOnLayersTab()) {
return;
}
// Else set the state to active and switch to the layers tab after a timeout
setLayersTabDndState('over');
timeoutRef.current = window.setTimeout(() => {
timeoutRef.current = null;
store.dispatch(activeTabCanvasRightPanelChanged('layers'));
// When we switch tabs, the other tab should be pending
setLayersTabDndState('idle');
setGalleryTabDndState('potential');
}, 300);
};
const onDragLeave = () => {
// Set the state to idle or pending depending on the current tab
if (getIsOnLayersTab()) {
setLayersTabDndState('idle');
} else {
setLayersTabDndState('potential');
}
// Abort the tab switch if it hasn't happened yet
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
const onDragStart = () => {
// Set the state to pending when a drag starts
setLayersTabDndState('potential');
};
return combine(
dropTargetForElements({
element: layersTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForElements({
canMonitor: ({ source }) => {
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
return false;
}
// Only monitor if we are not already on the gallery tab
return !getIsOnLayersTab();
},
onDragStart,
}),
dropTargetForExternal({
element: layersTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForExternal({
canMonitor: () => !getIsOnLayersTab(),
onDragStart,
})
);
}, [store]);
useEffect(() => {
if (!galleryTabRef.current) {
return;
}
const getIsOnGalleryTab = () => selectActiveTabCanvasRightPanel(store.getState()) === 'gallery';
const onDragEnter = () => {
// If we are already on the gallery tab, do nothing
if (getIsOnGalleryTab()) {
return;
}
// Else set the state to active and switch to the gallery tab after a timeout
setGalleryTabDndState('over');
timeoutRef.current = window.setTimeout(() => {
timeoutRef.current = null;
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
// When we switch tabs, the other tab should be pending
setGalleryTabDndState('idle');
setLayersTabDndState('potential');
}, 300);
};
const onDragLeave = () => {
// Set the state to idle or pending depending on the current tab
if (getIsOnGalleryTab()) {
setGalleryTabDndState('idle');
} else {
setGalleryTabDndState('potential');
}
// Abort the tab switch if it hasn't happened yet
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
const onDragStart = () => {
// Set the state to pending when a drag starts
setGalleryTabDndState('potential');
};
return combine(
dropTargetForElements({
element: galleryTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForElements({
canMonitor: ({ source }) => {
if (!singleImageDndSource.typeGuard(source.data) && !multipleImageDndSource.typeGuard(source.data)) {
return false;
}
// Only monitor if we are not already on the gallery tab
return !getIsOnGalleryTab();
},
onDragStart,
}),
dropTargetForExternal({
element: galleryTabRef.current,
onDragEnter,
onDragLeave,
}),
monitorForExternal({
canMonitor: () => !getIsOnGalleryTab(),
onDragStart,
})
);
}, [store]);
useEffect(() => {
const onDrop = () => {
// Reset the dnd state when a drop happens
setGalleryTabDndState('idle');
setLayersTabDndState('idle');
};
const cleanup = combine(monitorForElements({ onDrop }), monitorForExternal({ onDrop }));
return () => {
cleanup();
if (timeoutRef.current !== null) {
clearTimeout(timeoutRef.current);
}
};
}, []);
return (
<>
<Tab ref={layersTabRef} position="relative" w={32}>
<Box as="span" w="full">
{layersTabLabel}
</Box>
<DndDropOverlay dndState={layersTabDndState} withBackdrop={false} />
</Tab>
<Tab ref={galleryTabRef} position="relative" w={32}>
<Box as="span" w="full">
{t('gallery.gallery')}
</Box>
<DndDropOverlay dndState={galleryTabDndState} withBackdrop={false} />
</Tab>
</>
);
});
PanelTabs.displayName = 'PanelTabs';

View File

@@ -1,35 +0,0 @@
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityContainer';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
import { IPAdapterSettings } from 'features/controlLayers/components/IPAdapter/IPAdapterSettings';
import { CanvasEntityStateGate } from 'features/controlLayers/contexts/CanvasEntityStateGate';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useMemo } from 'react';
type Props = {
id: string;
};
export const IPAdapter = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'reference_image' }), [id]);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityStateGate entityIdentifier={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader ps={4} py={5}>
<CanvasEntityEditableTitle />
<Spacer />
<CanvasEntityHeaderCommonActions />
</CanvasEntityHeader>
<IPAdapterSettings />
</CanvasEntityContainer>
</CanvasEntityStateGate>
</EntityIdentifierContext.Provider>
);
});
IPAdapter.displayName = 'IPAdapter';

View File

@@ -1,37 +0,0 @@
/* eslint-disable i18next/no-literal-string */
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasEntityGroupList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityGroupList';
import { IPAdapter } from 'features/controlLayers/components/IPAdapter/IPAdapter';
import { selectCanvasSlice, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
const selectEntityIdentifiers = createMemoizedSelector(selectCanvasSlice, (canvas) => {
return canvas.referenceImages.entities.map(getEntityIdentifier).toReversed();
});
const selectIsSelected = createSelector(selectSelectedEntityIdentifier, (selectedEntityIdentifier) => {
return selectedEntityIdentifier?.type === 'reference_image';
});
export const IPAdapterList = memo(() => {
const isSelected = useAppSelector(selectIsSelected);
const entityIdentifiers = useAppSelector(selectEntityIdentifiers);
if (entityIdentifiers.length === 0) {
return null;
}
if (entityIdentifiers.length > 0) {
return (
<CanvasEntityGroupList type="reference_image" isSelected={isSelected} entityIdentifiers={entityIdentifiers}>
{entityIdentifiers.map((entityIdentifiers) => (
<IPAdapter key={entityIdentifiers.id} id={entityIdentifiers.id} />
))}
</CanvasEntityGroupList>
);
}
});
IPAdapterList.displayName = 'IPAdapterList';

View File

@@ -1,180 +0,0 @@
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
import { GlobalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/GlobalReferenceImageModel';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
referenceImageIPAdapterBeginEndStepPctChanged,
referenceImageIPAdapterCLIPVisionModelChanged,
referenceImageIPAdapterFLUXReduxImageInfluenceChanged,
referenceImageIPAdapterImageChanged,
referenceImageIPAdapterMethodChanged,
referenceImageIPAdapterModelChanged,
referenceImageIPAdapterWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntity, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CLIPVisionModelV2,
FLUXReduxImageInfluence as FLUXReduxImageInfluenceType,
IPMethodV2,
} from 'features/controlLayers/store/types';
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
createSelector(
selectCanvasSlice,
(canvas) => selectEntityOrThrow(canvas, entityIdentifier, 'IPAdapterSettings').ipAdapter
);
const IPAdapterSettingsContent = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('reference_image');
const selectIPAdapter = useMemo(() => buildSelectIPAdapter(entityIdentifier), [entityIdentifier]);
const ipAdapter = useAppSelector(selectIPAdapter);
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
dispatch(referenceImageIPAdapterBeginEndStepPctChanged({ entityIdentifier, beginEndStepPct }));
},
[dispatch, entityIdentifier]
);
const onChangeWeight = useCallback(
(weight: number) => {
dispatch(referenceImageIPAdapterWeightChanged({ entityIdentifier, weight }));
},
[dispatch, entityIdentifier]
);
const onChangeIPMethod = useCallback(
(method: IPMethodV2) => {
dispatch(referenceImageIPAdapterMethodChanged({ entityIdentifier, method }));
},
[dispatch, entityIdentifier]
);
const onChangeFLUXReduxImageInfluence = useCallback(
(imageInfluence: FLUXReduxImageInfluenceType) => {
dispatch(referenceImageIPAdapterFLUXReduxImageInfluenceChanged({ entityIdentifier, imageInfluence }));
},
[dispatch, entityIdentifier]
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
},
[dispatch, entityIdentifier]
);
const onChangeCLIPVisionModel = useCallback(
(clipVisionModel: CLIPVisionModelV2) => {
dispatch(referenceImageIPAdapterCLIPVisionModelChanged({ entityIdentifier, clipVisionModel }));
},
[dispatch, entityIdentifier]
);
const onChangeImage = useCallback(
(imageDTO: ImageDTO | null) => {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier, imageDTO }));
},
[dispatch, entityIdentifier]
);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ entityIdentifier }, ipAdapter.image?.image_name),
[entityIdentifier, ipAdapter.image?.image_name]
);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const isBusy = useCanvasIsBusy();
const isFLUX = useAppSelector(selectIsFLUX);
return (
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}
<IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
variant="ghost"
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
icon={<PiBoundingBoxBold />}
/>
</Flex>
<Flex gap={2} w="full">
{ipAdapter.type === 'ip_adapter' && (
<Flex flexDir="column" gap={2} w="full">
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
)}
{ipAdapter.type === 'flux_redux' && (
<Flex flexDir="column" gap={2} w="full" alignItems="flex-start">
<FLUXReduxImageInfluence
imageInfluence={ipAdapter.imageInfluence ?? 'lowest'}
onChange={onChangeFLUXReduxImageInfluence}
/>
</Flex>
)}
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
<IPAdapterImagePreview
image={ipAdapter.image}
onChangeImage={onChangeImage}
dndTarget={setGlobalReferenceImageDndTarget}
dndTargetData={dndTargetData}
/>
</Flex>
</Flex>
</Flex>
</CanvasEntitySettingsWrapper>
);
});
IPAdapterSettingsContent.displayName = 'IPAdapterSettingsContent';
const buildSelectIPAdapterHasImage = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
createSelector(selectCanvasSlice, (canvas) => {
const referenceImage = selectEntity(canvas, entityIdentifier);
return !!referenceImage && referenceImage.ipAdapter.image !== null;
});
export const IPAdapterSettings = memo(() => {
const entityIdentifier = useEntityIdentifierContext('reference_image');
const selectIPAdapterHasImage = useMemo(() => buildSelectIPAdapterHasImage(entityIdentifier), [entityIdentifier]);
const hasImage = useAppSelector(selectIPAdapterHasImage);
if (!hasImage) {
return <IPAdapterSettingsEmptyState />;
}
return <IPAdapterSettingsContent />;
});
IPAdapterSettings.displayName = 'IPAdapterSettings';

View File

@@ -2,8 +2,7 @@ import { Checkbox, ConfirmationAlertDialog, Flex, FormControl, FormLabel, Text }
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { newCanvasSessionRequested, newGallerySessionRequested } from 'features/controlLayers/store/actions';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
selectSystemShouldConfirmOnNewSession,
shouldConfirmOnNewSessionToggled,
@@ -17,15 +16,13 @@ const [useNewCanvasSessionDialog] = buildUseBoolean(false);
export const useNewGallerySession = () => {
const dispatch = useAppDispatch();
const imageViewer = useImageViewer();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewGallerySessionDialog();
const newGallerySessionImmediate = useCallback(() => {
dispatch(newGallerySessionRequested());
imageViewer.open();
dispatch(generateSessionReset());
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch, imageViewer]);
}, [dispatch]);
const newGallerySessionWithDialog = useCallback(() => {
if (shouldConfirmOnNewSession) {
@@ -40,15 +37,13 @@ export const useNewGallerySession = () => {
export const useNewCanvasSession = () => {
const dispatch = useAppDispatch();
const imageViewer = useImageViewer();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewCanvasSessionDialog();
const newCanvasSessionImmediate = useCallback(() => {
dispatch(newCanvasSessionRequested());
imageViewer.close();
dispatch(canvasSessionReset());
dispatch(activeTabCanvasRightPanelChanged('layers'));
}, [dispatch, imageViewer]);
}, [dispatch]);
const newCanvasSessionWithDialog = useCallback(() => {
if (shouldConfirmOnNewSession) {

View File

@@ -1,5 +1,5 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo } from 'react';
@@ -8,8 +8,8 @@ import { PiBoundingBoxBold } from 'react-icons/pi';
export const IPAdapterMenuItemPullBbox = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext('reference_image');
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const id = useRefImageIdContext();
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(id);
const isBusy = useCanvasIsBusy();
return (

View File

@@ -3,7 +3,7 @@ import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { IPAdapterMenuItemPullBbox } from 'features/controlLayers/components/IPAdapter/IPAdapterMenuItemPullBbox';
import { IPAdapterMenuItemPullBbox } from 'features/controlLayers/components/RefImage/IPAdapterMenuItemPullBbox';
import { memo } from 'react';
export const IPAdapterMenuItems = memo(() => {

View File

@@ -0,0 +1,238 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Divider,
Flex,
Icon,
IconButton,
Image,
Popover,
PopoverAnchor,
PopoverArrow,
PopoverBody,
PopoverContent,
Portal,
Skeleton,
Text,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { POPPER_MODIFIERS } from 'common/components/InformationalPopover/constants';
import type { UseDisclosure } from 'common/hooks/useBoolean';
import { useDisclosure } from 'common/hooks/useBoolean';
import { DEFAULT_FILTER, useFilterableOutsideClick } from 'common/hooks/useFilterableOutsideClick';
import { RefImageHeader } from 'features/controlLayers/components/RefImage/RefImageHeader';
import { RefImageSettings } from 'features/controlLayers/components/RefImage/RefImageSettings';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { isIPAdapterConfig } from 'features/controlLayers/store/types';
import { round } from 'lodash-es';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiExclamationMarkBold, PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
// There is some awkwardness here with closing the popover when clicking outside of it, related to Chakra's
// handling of refs, portals, outside clicks, and a race condition with framer-motion animations that can leave
// the popover closed when its internal state is still open.
//
// We have to manually manage the popover open state to work around the race condition, and then have to do special
// handling to close the popover when clicking outside of it.
// We have to reach outside react to identify the popover trigger element instead of using refs, thanks to how Chakra
// handles refs for PopoverAnchor internally. Maybe there is some way to merge them but I couldn't figure it out.
const getRefImagePopoverTriggerId = (id: string) => `ref-image-popover-trigger-${id}`;
export const RefImage = memo(() => {
const id = useRefImageIdContext();
const ref = useRef<HTMLDivElement>(null);
const disclosure = useDisclosure(false);
// This filter prevents the popover from closing when clicking on a sibling portal element, like the dropdown menu
// inside the ref image settings popover. It also prevents the popover from closing when clicking on the popover's
// own trigger element.
const filter = useCallback(
(el: HTMLElement | SVGElement) => {
return DEFAULT_FILTER(el) || el.id === getRefImagePopoverTriggerId(id);
},
[id]
);
useFilterableOutsideClick({ ref, handler: disclosure.close, filter });
return (
<Popover
// The popover contains a react-select component, which uses a portal to render its options. This portal
// is itself not lazy. As a result, if we do not unmount the popover when it is closed, the react-select
// component still exists but is invisible, and intercepts clicks!
isLazy
lazyBehavior="unmount"
isOpen={disclosure.isOpen}
closeOnBlur={false}
modifiers={POPPER_MODIFIERS}
>
<Thumbnail disclosure={disclosure} />
<Portal>
<PopoverContent ref={ref} w={400}>
<PopoverArrow />
<PopoverBody>
<Flex flexDir="column" gap={2} w="full" h="full">
<RefImageHeader />
<Divider />
<RefImageSettings />
</Flex>
</PopoverBody>
</PopoverContent>
</Portal>
</Popover>
);
});
RefImage.displayName = 'RefImage';
const baseSx: SystemStyleObject = {
opacity: 0.7,
transitionProperty: 'opacity',
transitionDuration: 'normal',
position: 'relative',
_hover: {
opacity: 1,
},
'&[data-is-open="true"]': {
opacity: 1,
},
'&[data-is-error="true"]': {
borderColor: 'error.500',
borderWidth: 2,
},
};
const weightDisplaySx: SystemStyleObject = {
pointerEvents: 'none',
transitionProperty: 'opacity',
transitionDuration: 'normal',
opacity: 0,
'&[data-visible="true"]': {
opacity: 1,
},
};
const getImageSxWithWeight = (weight: number): SystemStyleObject => {
const fillPercentage = Math.max(0, Math.min(100, weight * 100));
return {
...baseSx,
_after: {
content: '""',
position: 'absolute',
inset: 0,
background: `linear-gradient(to top, transparent ${fillPercentage}%, rgba(0, 0, 0, 0.8) ${fillPercentage}%)`,
pointerEvents: 'none',
borderRadius: 'base',
},
};
};
const Thumbnail = memo(({ disclosure }: { disclosure: UseDisclosure }) => {
const id = useRefImageIdContext();
const entity = useRefImageEntity(id);
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
const { data: imageDTO } = useGetImageDTOQuery(entity.config.image?.image_name ?? skipToken);
const sx = useMemo(() => {
if (!isIPAdapterConfig(entity.config)) {
return baseSx;
}
return getImageSxWithWeight(entity.config.weight);
}, [entity.config]);
useEffect(() => {
if (!isIPAdapterConfig(entity.config)) {
return;
}
setShowWeightDisplay(true);
const timeout = window.setTimeout(() => {
setShowWeightDisplay(false);
}, 1000);
return () => {
window.clearTimeout(timeout);
};
}, [entity.config]);
if (!entity.config.image) {
return (
<PopoverAnchor>
<IconButton
id={getRefImagePopoverTriggerId(id)}
aria-label="Open Reference Image Settings"
h="full"
variant="ghost"
aspectRatio="1/1"
borderWidth="2px !important"
borderStyle="dashed !important"
borderColor="errorAlpha.500"
borderRadius="base"
icon={<PiImageBold />}
colorScheme="error"
onClick={disclosure.toggle}
flexShrink={0}
/>
</PopoverAnchor>
);
}
return (
<PopoverAnchor>
<Flex
position="relative"
borderWidth={1}
borderStyle="solid"
borderRadius="base"
aspectRatio="1/1"
maxW="full"
maxH="full"
flexShrink={0}
sx={sx}
data-is-open={disclosure.isOpen}
data-is-error={!entity.config.model}
id={getRefImagePopoverTriggerId(id)}
role="button"
onClick={disclosure.toggle}
cursor="pointer"
>
<Image
src={imageDTO?.thumbnail_url}
objectFit="contain"
aspectRatio="1/1"
height={imageDTO?.height}
fallback={<Skeleton h="full" aspectRatio="1/1" />}
maxW="full"
maxH="full"
borderRadius="base"
/>
{isIPAdapterConfig(entity.config) && (
<Flex
position="absolute"
inset={0}
fontWeight="semibold"
alignItems="center"
justifyContent="center"
zIndex={1}
data-visible={showWeightDisplay}
sx={weightDisplaySx}
>
<Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
{`${round(entity.config.weight * 100, 2)}%`}
</Text>
</Flex>
)}
{!entity.config.model && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="error.500"
boxSize={16}
as={PiExclamationMarkBold}
/>
)}
</Flex>
</PopoverAnchor>
);
});
Thumbnail.displayName = 'Thumbnail';

View File

@@ -0,0 +1,41 @@
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { refImageDeleted } from 'features/controlLayers/store/refImagesSlice';
import { memo, useCallback } from 'react';
import { PiTrashBold } from 'react-icons/pi';
export const RefImageHeader = memo(() => {
const id = useRefImageIdContext();
const dispatch = useAppDispatch();
const entity = useRefImageEntity(id);
const deleteRefImage = useCallback(() => {
dispatch(refImageDeleted({ id }));
}, [dispatch, id]);
return (
<Flex justifyContent="space-between" alignItems="center" w="full">
{entity.config.image !== null && (
<Text fontWeight="semibold" color="base.300">
Reference Image
</Text>
)}
{entity.config.image === null && (
<Text fontWeight="semibold" color="base.300">
No Reference Image Selected
</Text>
)}
<IconButton
size="xs"
variant="link"
alignSelf="stretch"
icon={<PiTrashBold />}
onClick={deleteRefImage}
aria-label="Delete reference image"
colorScheme="error"
/>
</Flex>
);
});
RefImageHeader.displayName = 'RefImageHeader';

View File

@@ -1,7 +1,7 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { UploadImageButton } from 'common/hooks/useImageUploadButton';
import { UploadImageIconButton } from 'common/hooks/useImageUploadButton';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import type { setGlobalReferenceImageDndTarget, setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -21,7 +21,7 @@ type Props<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegiona
dndTargetData: ReturnType<T['getData']>;
};
export const IPAdapterImagePreview = memo(
export const RefImageImage = memo(
<T extends typeof setGlobalReferenceImageDndTarget | typeof setRegionalGuidanceReferenceImageDndTarget>({
image,
onChangeImage,
@@ -51,7 +51,7 @@ export const IPAdapterImagePreview = memo(
return (
<Flex position="relative" w="full" h="full" alignItems="center" data-error={!imageDTO && !image?.image_name}>
{!imageDTO && (
<UploadImageButton
<UploadImageIconButton
w="full"
h="full"
isError={!imageDTO && !image?.image_name}
@@ -77,4 +77,4 @@ export const IPAdapterImagePreview = memo(
}
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
RefImageImage.displayName = 'RefImageImage';

Some files were not shown because too many files have changed in this diff Show More