Compare commits

...

389 Commits

Author SHA1 Message Date
psychedelicious
82fb897b62 chore(ui): lint 2025-07-12 14:56:57 +10:00
psychedelicious
192b00d969 chore: bump version to v6.0.2 2025-07-12 14:56:57 +10:00
psychedelicious
7bb25ef1b4 fix(ui): gallery dnd 2025-07-12 14:56:57 +10:00
psychedelicious
62f52c74a8 fix(ui): linked negative style prompt not passed in
Closes #8256
2025-07-12 10:22:17 +10:00
psychedelicious
97439c1daa fix(ui): native context menu shown on right click on short fat images
Closes #8254
2025-07-12 10:22:17 +10:00
psychedelicious
b23bff1b53 fix(ui): center staging area images 2025-07-12 10:22:17 +10:00
psychedelicious
d9a1efbabf fix(ui): staging area images may be slightly too large 2025-07-12 10:22:17 +10:00
psychedelicious
d4e903ee2d chore: bump version to v6.0.1 2025-07-12 10:22:17 +10:00
Kevin Turner
bb3e5d16d8 feat(Model Manager): refuse to download a file when there's insufficient space 2025-07-12 10:14:25 +10:00
psychedelicious
e62d3f01a8 feat(app): better error message for failed model probe
- Old: No valid config found
- New: Unable to determine model type
2025-07-11 23:35:43 +10:00
psychedelicious
757ecdbf82 build(ui): downgrade idb-keyval
We have seen increased error rates since updating this package. Let's try
downgrading to see if that fixes the issue.
2025-07-11 15:00:10 +10:00
psychedelicious
694c85b041 fix(ui): language file filenames
Need to replace the underscores w/ dashes - this was missed in #8246.
2025-07-11 14:21:41 +10:00
psychedelicious
988d7ba24c chore: bump version to v6.0.1rc1 2025-07-11 09:05:24 +10:00
psychedelicious
ac981879ef fix(ui): runtime errors related to calling reduce on array iterator
Fix an issue in certain browsers/builds causing a runtime error.

A zod enum has a .options property, which is an array of all the options
for the enum. This is handy for when you need to derive something from a
zod schema.

In this case, we represented the possible focus regions in the zod enum,
then derived a mapping of region names to set of target HTML elements.
The why isn't important; suffice it to say, we were using the .options
property for this.

But actually, we were using .options.values(), then calling .reduce() on
that. An array's .values() method returns an _array iterator_. Array
iterators do not have .reduce() methods!

Except, apparently in some environments they do - it depends on the JS
engine and whether or not polyfills for iterator helpers were included
in the build.

Turns out my dev environment - and most user browsers - do provide
.reduce(), so we didn't catch this error. It took a large deployment and
error monitoring to catch it.

I've refactored the code to totally avoid deriving data from zod in this
way.
2025-07-11 08:25:47 +10:00
psychedelicious
fc71849c24 feat(app): expose a cursor, not a connection in db util 2025-07-11 08:20:06 +10:00
psychedelicious
a19aa3b032 feat(app): db abstraction to prevent threading conflicts
- Add a context manager to the SqliteDatabase class which abstracts away
creating a transaction, committing it on success and rolling back on
error.
- Use it everywhere. The context manager should be exited before
returning results. No business logic changes should be present.
2025-07-11 08:20:06 +10:00
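For illustration, a minimal sketch of what such a transaction context manager could look like, assuming a simple wrapper around sqlite3 - the class and method names here are illustrative, not the actual InvokeAI implementation:

```python
import sqlite3
from contextlib import contextmanager
from typing import Iterator


class SqliteDatabase:
    """Illustrative stand-in for a thread-safe SQLite wrapper."""

    def __init__(self, db_path: str) -> None:
        self._conn = sqlite3.connect(db_path, check_same_thread=False)

    @contextmanager
    def transaction(self) -> Iterator[sqlite3.Cursor]:
        # Hand out a cursor; commit on success, roll back on any error.
        cursor = self._conn.cursor()
        try:
            yield cursor
            self._conn.commit()
        except Exception:
            self._conn.rollback()
            raise
        finally:
            cursor.close()


# Usage - exit the context manager before returning results:
# db = SqliteDatabase("invokeai.db")
# with db.transaction() as cursor:
#     cursor.execute("SELECT count(*) FROM session_queue")
#     (count,) = cursor.fetchone()
```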
psychedelicious
ef4d5d7377 feat(ui): virtualized list for staging area
Make the staging area a virtualized list so it doesn't choke when there
are a large number (i.e. more than a few hundred) of queue items.
2025-07-11 07:50:57 +10:00
Mary Hipp Rogers
6b0dfd8427 dont reset canvas if studio is loaded with canvas destination (#8252)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-07-10 09:36:41 -04:00
psychedelicious
471c010217 fix(ui): invalid language crashes app
- Apparently locales must use hyphens instead of underscores. This must
have been a fairly recent change that we didn't catch. It caused i18n to
throw for Brazilian Portuguese and both Simplified and Traditional
Mandarin. Change the locales to use the right strings.
- Move the theme + locale provider inside of the error boundary. This
allows errors with locales to be caught by the error boundary instead of
hard-crashing the app. The error screen is unstyled if this happens but
at least it has the reset button.
- Add a migration for the system slice to fix existing users' language
selections. For example, if the user had an incorrect language setting
of `zh_CN`, it will be changed to the correct `zh-CN`.
2025-07-10 14:27:36 +10:00
psychedelicious
b1193022f7 fix(ui): sometimes images added to gallery show as placeholder only
The range-based fetching logic had a subtle bug - it didn't keep track
of what the _current_ visible range is - only the ranges that the user
last scrolled to.

When an image was added to the gallery, the logic saw that the images
had changed, but thought it had already loaded everything it needed to,
so it didn't load the new image.

The updated logic tracks the current visible range separately from the
accumulated scroll ranges to address this issue.
2025-07-10 14:27:36 +10:00
psychedelicious
2152ca092c fix(ui): workaround for dockview bug that lets you drag tabs in certain ways 2025-07-10 14:27:36 +10:00
psychedelicious
ccc62ba56d perf(ui): revised range-based fetching strategy
When the user scrolls in the gallery, we are alerted of the new range of
visible images. Then we fetch those specific images.

Previously, each change of range triggered a throttled function to fetch
that range. The throttle timeout was 100ms.

Now, each change of range appends that range to a list of ranges and
triggers the throttled fetch. The timeout is increased to 500ms, but to
compensate, each fetch handles all ranges that had been accumulated
since the last fetch.

The result is far fewer network requests, but each of them gets more
images.
2025-07-10 14:27:36 +10:00
psychedelicious
9cf82de8c5 fix(ui): check for absolute value of scroll velocity to handle scrolling up 2025-07-10 14:27:36 +10:00
psychedelicious
aced349152 perf(ui): increase viewport in gallery
This allows us to prefetch more images and reduce how often placeholders
are shown as we fetch more images in the gallery.
2025-07-10 14:27:36 +10:00
psychedelicious
0d67ee6548 tests(ui): fix logging mock 2025-07-09 23:15:25 +10:00
psychedelicious
03c21d1607 fix(ui): gallery not updating when saving staging area image 2025-07-09 23:15:25 +10:00
psychedelicious
752e8db1f5 tidy(ui): demote logging in nav api to trace 2025-07-09 23:15:25 +10:00
psychedelicious
85fc861dd9 chore(ui): lint 2025-07-09 23:15:25 +10:00
psychedelicious
458cbfd874 fix(ui): selected model not highlighted 2025-07-09 23:15:25 +10:00
psychedelicious
04331c070a fix(ui): set denoise w/h when running flux fill 2025-07-09 23:15:25 +10:00
psychedelicious
632ddf0cb4 tests(ui): update tests for navigation api 2025-07-09 23:15:25 +10:00
psychedelicious
2b193ff416 fix(ui): delete stored state on error & save new state 2025-07-09 23:15:25 +10:00
psychedelicious
96ee394f9e refactor(ui): use dockview's own ser/de for persistence 2025-07-09 23:15:25 +10:00
psychedelicious
0badc80c0c fix(ui): ignore disabled ref images in readiness checks 2025-07-09 23:15:25 +10:00
psychedelicious
78e6cbf96e fix(ui): default tab is generate 2025-07-09 23:15:25 +10:00
psychedelicious
0b969a661b fix(ui): remove dep on focus from useDeleteImage 2025-07-09 23:15:25 +10:00
psychedelicious
6fe47ec9f8 feat(ui): improve ref image model autoswitch logic 2025-07-09 23:15:25 +10:00
Kent Keirsey
3850dd61f8 update comment 2025-07-09 23:15:25 +10:00
Kent Keirsey
75520eaf0f Match Chatgpt4o and kontext names exactly 2025-07-09 23:15:25 +10:00
Kent Keirsey
10e88c58c1 fix and lint 2025-07-09 23:15:25 +10:00
Kent Keirsey
30ed4dbd92 lint 2025-07-09 23:15:25 +10:00
Kent Keirsey
ed9c090f33 fixes 2025-07-09 23:15:25 +10:00
Kent Keirsey
d29f65ed22 lint fixes 2025-07-09 23:15:25 +10:00
Kent Keirsey
2062ec8ac0 Update invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
Co-authored-by: Mary Hipp Rogers <maryhipp@gmail.com>
2025-07-09 23:15:25 +10:00
Cursor Agent
49e818338a Changes from background composer bc-abfadb27-a265-41a7-b0db-829879f4701e 2025-07-09 23:15:25 +10:00
Cursor Agent
1caab2b9c4 Implement automatic reference image model switching on base model change
Co-authored-by: kent <kent@invoke.ai>
2025-07-09 23:15:25 +10:00
psychedelicious
50079ea349 fix(ui): big red cancel button has different behaviour from staging discard 2025-07-09 23:15:25 +10:00
psychedelicious
fffa1b24c4 fix(ui): isStaging selector could return wrong query cache 2025-07-09 23:15:25 +10:00
psychedelicious
a6d6170387 fix(ui): discarding 1 item when 2 items left in staging area discards both 2025-07-09 23:15:25 +10:00
psychedelicious
e5fceb0448 fix(ui): whole app scrolls while selecting staging area image 2025-07-09 23:15:25 +10:00
psychedelicious
059baf5b29 chore(ui): lint 2025-07-09 23:15:25 +10:00
psychedelicious
1be8a9a310 fix(ui): add metadata i18nKey to handler; fixes metadata toasts 2025-07-09 23:15:25 +10:00
psychedelicious
7adc33e04d refactor(ui): metadata recall buttons & hotkeys (WIP) 2025-07-09 23:15:25 +10:00
psychedelicious
7f2dd22d47 refactor(ui): metadata recall buttons & hotkeys (WIP) 2025-07-09 23:15:25 +10:00
psychedelicious
bb50f4b8a2 fix(ui): prevent panels from growing on init
This works but I think a better solution is to use dockview's provided
serialization API to store and restore layouts.
2025-07-09 23:15:25 +10:00
psychedelicious
a48958e0d4 chore(ui): lint 2025-07-09 23:15:25 +10:00
psychedelicious
e3a1e9af53 feat(ui): staging area updates
- Smaller staged image previews.
- Move autoswitch buttons to staging area toolbar, remove from settings
popover and the little three-dots menu. Use persisted autoswitch
setting, which is renamed from `defaultAutoSwitch` to
`stagingAreaAutoSwitch`.
- Fix issue with misaligned border radii in staging area preview images.
Required small changes to DndImage and its usage elsewhere.
- Fix issue where staging area toolbar could show up without any
previews in the list.
- Migrate canvas settings slice to use zod schema and inferred types for
its state.
2025-07-09 23:15:25 +10:00
psychedelicious
c6fe11c42f fix(ui): disable gallery hotkeys when in staging area 2025-07-09 23:15:25 +10:00
psychedelicious
4eb1bd67df fix(ui): hide staging area when there are no items 2025-07-09 23:15:25 +10:00
psychedelicious
c376f914d2 chore: bump version v6.0.0 2025-07-09 23:15:25 +10:00
Kent Keirsey
b5d1c47ef7 final link fix 2025-07-09 10:17:38 +10:00
Kent Keirsey
004a52ca65 fix to direct links 2025-07-09 10:17:38 +10:00
Kent Keirsey
b1d5a51ddf add-quantized-kontext-dev 2025-07-09 10:17:38 +10:00
Kent Keirsey
2b2498eaa1 fix prettier quirk 2025-07-08 14:54:29 -04:00
Kent Keirsey
10dda4440e Fix label 2025-07-08 14:54:29 -04:00
Cursor Agent
98f78abefa Add default auto-switch mode setting for canvas sessions
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 14:54:29 -04:00
Mary Hipp Rogers
cc93fa270f update whats new for v6 (#8234)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 18:24:33 +00:00
Mary Hipp Rogers
014b27680f fix flux kontext error (#8235)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 13:42:48 -04:00
Mary Hipp Rogers
c3d8f875de if on generate tab, recall dimensions instead of bbox (#8233)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 13:09:21 -04:00
Mary Hipp Rogers
79f9dc6e4a fix(ui): dont show option to add new layer from if on generate tab (#8231)
* dont show option to add new layer from if on generate tab

* only disable width/height recall if staging AND canvas tab

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 11:46:54 -04:00
psychedelicious
6e1c0c1105 chore: bump version to v6.0.0rc5 2025-07-08 11:26:47 -04:00
Mary Hipp Rogers
0362524040 remove hard-coded flux kontext dev guidance (#8230)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 10:26:20 -04:00
psychedelicious
dc6656459b docs(ui): updated comments for navigation api 2025-07-08 07:30:36 -04:00
psychedelicious
3ea1b97f6f fix(ui): protect against getting stuck on tab loading screen 2025-07-08 07:30:36 -04:00
psychedelicious
a7c7405ccc feat(ui): style model picker selected item 2025-07-08 07:28:07 -04:00
psychedelicious
c391f1117a fix(ui): traverse groups when finding selected model in picker 2025-07-08 07:28:07 -04:00
psychedelicious
b1e2cb8401 fix(ui): queue tab list of queue items
Reverted incomplete change to how queue items are listed. In the future
I think we should redo it to work like the gallery. For now, it is back
the way it was in v5.
2025-07-08 07:22:51 -04:00
Emmanuel Ferdman
db6af134b7 fix: resolve FastAPI deprecation warning for example fields
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-07-08 20:54:08 +10:00
Emmanuel Ferdman
7e6cffb00c fix: resolve FastAPI deprecation warning for example fields
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-07-08 20:54:08 +10:00
psychedelicious
5b187bcb00 fix(ui): pull bbox into ref image component 2025-07-08 14:54:43 +10:00
psychedelicious
0843d609a3 feat(ui): add list of warnings in tooltip on ref image 2025-07-08 14:54:43 +10:00
Kent Keirsey
95bd9cef18 Lint 2025-07-08 14:54:43 +10:00
Kent Keirsey
931d6521f6 Adds bbox to ref image button 2025-07-08 14:54:43 +10:00
psychedelicious
e37665ff59 tests(ui): add wiggle room to timeout tests 2025-07-08 12:55:33 +10:00
psychedelicious
56857fbbe6 tests(ui): add tests for panel storage 2025-07-08 12:55:33 +10:00
psychedelicious
43cfb8a574 tests(ui): get tests passing
Still need tests for panel storage.
2025-07-08 12:55:33 +10:00
psychedelicious
05b1682d15 fix(ui): handle collapsed panels when rehydrating their state 2025-07-08 12:55:33 +10:00
psychedelicious
69a08ee7f2 feat(ui): panel state persistence (WIP) 2025-07-08 12:55:33 +10:00
psychedelicious
18212c7d8a feat(ui): clean up navigation API surface and add comments 2025-07-08 12:55:33 +10:00
psychedelicious
7de26f8e69 feat(ui): clean up auto layout context for panels 2025-07-08 12:55:33 +10:00
Kent Keirsey
0652b12a6f Address comments 2025-07-08 12:31:11 +10:00
Kent Keirsey
43a361a00f prettier 2025-07-08 12:31:11 +10:00
Kent Keirsey
cf68ad9cbc update links to playlist instead of video 2025-07-08 12:31:11 +10:00
Kent Keirsey
ec02a39325 fixes 2025-07-08 12:31:11 +10:00
Kent Keirsey
e52d7a05c2 Update support links. 2025-07-08 12:31:11 +10:00
Cursor Agent
c9d4e2b761 Refactor support videos modal to simplify video and playlist handling
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 12:31:11 +10:00
Kent Keirsey
ac26aa9508 fix 2025-07-08 12:31:11 +10:00
Cursor Agent
9ff6ada15b Add support for video playlists in support videos modal
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 12:31:11 +10:00
psychedelicious
e81a115169 chore(ui): lint 2025-07-08 12:23:57 +10:00
Kent Keirsey
52827807de remove ref image from upscale 2025-07-08 12:23:57 +10:00
Kent Keirsey
b631de4cb5 consistency 2025-07-08 12:20:08 +10:00
Kent Keirsey
099ebdbc37 fix 2025-07-08 12:20:08 +10:00
psychedelicious
4de6549be9 refactor(ui): track discarded items instead of using delete method 2025-07-08 12:12:55 +10:00
psychedelicious
368be34949 chore(ui): lint 2025-07-08 12:12:55 +10:00
psychedelicious
5baa4bd916 refactor(ui): use cancelation for staging area (mostly) 2025-07-08 12:12:55 +10:00
psychedelicious
4229377532 fix(app): ensure cancel events are emitted for current item when bulk canceling
There was a bug where bulk cancel operations would cancel the current
queue item in the DB but not emit the status changed events correctly.
2025-07-08 12:12:55 +10:00
psychedelicious
2610772ffd feat(ui): tighten up launchpad content to fit better 2025-07-08 08:57:44 +10:00
psychedelicious
193de6a8f2 feat(ui): add launchpad container component 2025-07-08 08:57:44 +10:00
psychedelicious
7ea343c787 tidy(ui): remove "staging" from the new settings verbiage 2025-07-08 07:10:55 +10:00
Kent Keirsey
12179dabba fix prettier 2025-07-08 07:10:55 +10:00
Cursor Agent
ef135f9923 Add option to save all staging images to gallery in canvas mode
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 07:10:55 +10:00
Mary Hipp
e6c67cc00f update toast for prompt expansion failed 2025-07-08 06:42:00 +10:00
psychedelicious
179b988148 fix(ui): prompt concat derived state recall 2025-07-08 06:37:43 +10:00
psychedelicious
d913a3c85b fix(ui): reset selected ref image when replacing all
Fixes an unhandled error in a selector that can throw.
2025-07-08 06:37:43 +10:00
psychedelicious
e79525c40c docs(ui): update comments 2025-07-08 06:11:32 +10:00
psychedelicious
f409f913ac fix(ui): navigation api usage 2025-07-08 06:11:32 +10:00
Mary Hipp
7a79f61d4c add claude nodes to blacklist for publishing 2025-07-08 05:50:40 +10:00
psychedelicious
ea182c234b chore: bump version to v6.0.0rc4 2025-07-07 22:15:28 +10:00
psychedelicious
f2eee4a82d chore(ui): lint 2025-07-07 22:05:49 +10:00
psychedelicious
e129525306 fix(app): handle None in queue count queries 2025-07-07 22:05:49 +10:00
psychedelicious
ecedfce758 feat(ui): support a min expanded size for collapsible panels 2025-07-07 22:05:49 +10:00
psychedelicious
702cb2cb1e fix(ui): flux kontext special handling for ref image models 2025-07-07 22:05:49 +10:00
psychedelicious
2e8db3cce3 fix(ui): ensure noise is correctly sized 2025-07-07 22:05:49 +10:00
psychedelicious
7845623fa5 fix(ui): session context indexing bug 2025-07-07 22:05:49 +10:00
psychedelicious
e6a25ca7a2 feat(ui): render progress as indeterminate when percentage is 0
When percentage is zero, the progress bar looks the same as it does when
no generation is in progress. Render it as indeterminate (pulsing) when
percentage is zero to indicate that something is happening.
2025-07-07 22:05:49 +10:00
psychedelicious
71e12bcebe fix(ui): when no negative prompt is provided, recall it as null 2025-07-07 22:05:49 +10:00
psychedelicious
863c7eb9e2 fix(ui): metadata display for primitive values 2025-07-07 22:05:49 +10:00
psychedelicious
9945c20d02 refactor(ui): simplify graph builders (WIP) 2025-07-07 22:05:49 +10:00
psychedelicious
e3c1334b1f refactor(ui): simplify graph builders (WIP) 2025-07-07 22:05:49 +10:00
psychedelicious
c143f63ef0 refactor(ui): simplify graph builders (WIP) 2025-07-07 22:05:49 +10:00
psychedelicious
067026a0d0 feat(ui): add autocomplete for Graph.addEdgeToMetadata 2025-07-07 22:05:49 +10:00
psychedelicious
66991334fc refactor(ui): simplify graph builder handling of VAE encode and seed 2025-07-07 22:05:49 +10:00
psychedelicious
b771c3b164 refactor(ui): update graphs to use the right w/h/aspect 2025-07-07 22:05:49 +10:00
psychedelicious
4925694dc1 feat(ui): generate tab has separate w/h/aspect 2025-07-07 22:05:49 +10:00
psychedelicious
0a737ced44 feat(ui): add dimensions to params slice 2025-07-07 22:05:49 +10:00
psychedelicious
8d83caaae0 feat(ui): extract aspect ratios from canvas reducers 2025-07-07 22:05:49 +10:00
psychedelicious
16c8017f1a feat(ui): more resilient gallery scrollIntoView 2025-07-07 22:05:49 +10:00
psychedelicious
61a35f1396 fix(ui): skip optimistic updates for gallery when using search term 2025-07-07 22:05:49 +10:00
psychedelicious
6bd004d868 fix(ui): clear ref images when recalling all
Closes #8202
2025-07-07 22:05:49 +10:00
psychedelicious
b6a6d406c7 chore(ui): typegen 2025-07-07 10:25:24 +10:00
psychedelicious
8e287c32ee chore(ui): lint 2025-07-07 10:25:24 +10:00
psychedelicious
2d8b5e26c2 build(ui): bump vite to latest 2025-07-07 10:25:24 +10:00
psychedelicious
50914b74ee chore(build): update pnpm to v10 2025-07-07 10:25:24 +10:00
psychedelicious
0fc1c33536 chore(ui): knip 2025-07-07 10:25:24 +10:00
psychedelicious
3b08c35f72 chore(ui): update knip config 2025-07-07 10:25:24 +10:00
psychedelicious
607b2561fd chore(ui): bump knip to latest 2025-07-07 10:25:24 +10:00
psychedelicious
d68f922efb fix(ui): restore upscale-tab-specific settings components 2025-07-07 10:25:24 +10:00
psychedelicious
2bbd74d418 feat(ui): restore canvas busy spinner 2025-07-07 10:25:24 +10:00
psychedelicious
3a5392a9ee chore: bump version to v6.0.0rc3 2025-07-04 20:46:08 +10:00
psychedelicious
6f80efe71d fix(ui): bump expandprompt timeout to 15s 2025-07-04 20:46:08 +10:00
psychedelicious
7fac833813 fix(ui): ref image model types again 2025-07-04 20:35:29 +10:00
psychedelicious
b67eb4134d fix(ui): select next image when deleting 2025-07-04 20:35:29 +10:00
psychedelicious
522eeda2e2 fix(ui): ref image model types 2025-07-04 20:35:29 +10:00
psychedelicious
76233241f0 fix(ui): include ref image metadata for flux kontext 2025-07-04 20:35:29 +10:00
psychedelicious
54be9989c5 feat(ui): add 'replace' and 'merge' strategies for upsertMetadata 2025-07-04 20:35:29 +10:00
psychedelicious
0d3af08d27 fix(ui): prompt parsing in useImageActions 2025-07-04 20:35:29 +10:00
psychedelicious
767ac91f2c fix(nodes): revert unnecessary version bump 2025-07-04 20:35:29 +10:00
psychedelicious
68571ece8f tidy(app): remove unused methods 2025-07-04 20:35:29 +10:00
psychedelicious
01100a2b9a fix(ui): check for ref image config compatibility for flux kontext dev 2025-07-04 20:35:29 +10:00
psychedelicious
ce2e6d8ab6 fix(ui): kontext gen mode error tkey 2025-07-04 20:35:29 +10:00
psychedelicious
4887424ca3 chore: ruff 2025-07-04 20:35:29 +10:00
Kent Keirsey
28f6a20e71 format import block 2025-07-04 20:35:29 +10:00
Kent Keirsey
c4142e75b2 fix import 2025-07-04 20:35:29 +10:00
Kent Keirsey
fefe563127 fix resizing and versioning 2025-07-04 20:35:29 +10:00
Mary Hipp
1c72f1ff9f include flux kontext non-api models in ref image dropdown options 2025-07-04 20:35:29 +10:00
Mary Hipp
605cc7369d update flux kontext implementation to include flux kontext dev non-api models 2025-07-04 20:35:29 +10:00
Kent Keirsey
e7ce08cffa ruff format 2025-07-04 19:24:44 +10:00
Kent Keirsey
983cb5ebd2 ruff ruff 2025-07-04 19:24:44 +10:00
Kent Keirsey
52dbdb7118 ruff 2025-07-04 19:24:44 +10:00
Kent Keirsey
71e6f00e10 test fixes
fix

test

fix 2

fix 3

fix 4

yet another

attempt new fix

pray

more pray

lol
2025-07-04 19:24:44 +10:00
psychedelicious
e73150c3e6 feat(ui): improved automatic tab/panel switching on user actions 2025-07-04 19:18:03 +10:00
psychedelicious
f2426c3ab2 fix(ui): type for dnd action 2025-07-04 19:18:03 +10:00
psychedelicious
9d9c4c0f1a tidy(ui): remove unused old metadata impl 2025-07-04 17:53:47 +10:00
psychedelicious
acb930f6b9 fix(ui): flux redux saves metadata 2025-07-04 17:53:47 +10:00
psychedelicious
585b54dc7d feat(ui): ref image recall w/ old canvas metadata backup 2025-07-04 17:53:47 +10:00
psychedelicious
f65affc0ec fix(ui): do not attempt to recall ref images from canvas metadata 2025-07-04 17:53:47 +10:00
psychedelicious
22d574c92a feat(ui): canvas metadata recall 2025-07-04 17:53:47 +10:00
psychedelicious
f23be119fc refactor(ui): migrating to new metadata handlers 2025-07-04 17:53:47 +10:00
psychedelicious
2d06949e80 feat(ui): display cached metadata if it exists instead of always waiting for debounce 2025-07-04 17:53:47 +10:00
psychedelicious
67804313e1 fix(ui): add ref images to metadata 2025-07-04 17:53:47 +10:00
psychedelicious
dc23be117a refactor(ui): simplified metadata parsing (WIP) 2025-07-04 17:53:47 +10:00
psychedelicious
350de058fc refactor(ui): simplified metadata parsing (WIP) 2025-07-04 17:53:47 +10:00
psychedelicious
fd5cd707a3 refactor(ui): simplified metadata parsing (WIP) 2025-07-04 17:53:47 +10:00
psychedelicious
98ecefdce0 refactor(ui): simplified metadata parsing (WIP) 2025-07-04 17:53:47 +10:00
psychedelicious
42688a0993 refactor(ui): metadata parsing 2025-07-04 17:53:47 +10:00
psychedelicious
d94aa4abf7 feat(ui): enforce loader when switching tabs 2025-07-04 16:49:57 +10:00
psychedelicious
69a56aafed feat(ui): do not require root ref to focus on prompt 2025-07-04 16:49:57 +10:00
psychedelicious
56873f6936 feat(ui): queue and models tab are wrapped in dockview panels 2025-07-04 16:49:57 +10:00
psychedelicious
6bc6a680cf tests(ui): NavigationApi 2025-07-04 16:49:57 +10:00
psychedelicious
9a49682f60 feat(ui): utils to get tab/panel keys to prevent typos 2025-07-04 16:49:57 +10:00
psychedelicious
ff84b0a495 refactor(ui): navigation api 2025-07-04 16:49:57 +10:00
psychedelicious
bcced8a5e8 refactor(ui): navigation api 2025-07-04 16:49:57 +10:00
psychedelicious
4a18e9eaea refactor(ui): panel api (WIP) 2025-07-04 16:49:57 +10:00
psychedelicious
dde5bf61be feat(ui): use exact brand colors in loader 2025-07-04 16:49:57 +10:00
psychedelicious
987e401709 perf(ui): lora components 2025-07-04 14:55:52 +10:00
psychedelicious
5c5ac570e3 fix(ui): hardcode literals for run graph errors
When we build, the class names are minified. This hardcodes the values
to literals.
2025-07-04 14:52:08 +10:00
psychedelicious
309903fe0f feat(ui): refetch gallery image names on reconnect
Maybe fixes JP's issue (again)
2025-07-04 14:49:32 +10:00
psychedelicious
f16ea43e9a feat(ui): enable RTK Query's refetchOnReconnect 2025-07-04 14:49:32 +10:00
Jeremy Gooch
d794aedb43 fix(ui): sets cfg_rescale_multiplier to 0 if there is no default. Also fixes issue with truthiness check causing 0 value to be missed. See https://github.com/invoke-ai/InvokeAI/issues/7584 2025-07-04 06:20:14 +10:00
psychedelicious
9930440f33 chore: bump version to v6.0.0rc2 2025-07-03 12:35:04 +10:00
psychedelicious
f0a6c4aa1f fix(ui): after canceling a filter, layer loses its content 2025-07-03 12:30:01 +10:00
psychedelicious
f36d22f13c fix(ui): control layers ignored in txt2img 2025-07-03 12:27:05 +10:00
Cursor Agent
e0d7fab524 Fix: Toggle right panel instead of left panel in navigation
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:15:22 +10:00
Cursor Agent
f20c230f4a Add drag-and-drop comparison image target to ImageViewerPanel
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:10:51 +10:00
Cursor Agent
05c9bc730e Fix canvas export layer bounds calculation in PSD export hook
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:07:22 +10:00
Cursor Agent
f17ac06591 Fix PSD export to use layer content bounds and crop canvas
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:07:22 +10:00
Kent Keirsey
b35f93d919 Change implementation to check $ispending 2025-07-03 12:04:27 +10:00
Cursor Agent
289d8076d8 Reset canvas session when queue item is canceled in current session
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:04:27 +10:00
skunkworxdark
604763d20f Update flux.py
Replace T5Tokenizer with T5TokenizerFast
2025-07-03 08:04:08 +10:00
Mary Hipp
7b452f098d lint 2025-07-02 16:27:44 -04:00
Mary Hipp
b41c18d35f disable dropzone if prompt expansion is disabled 2025-07-02 16:27:44 -04:00
Mary Hipp
8328081333 properly build batch for flux kontext api batches 2025-07-02 14:27:57 -04:00
Mary Hipp Rogers
07517cf2c2 remove pulsing animation (#8181)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-02 16:12:52 +00:00
Kent Keirsey
6b98ad9095 Only display one icon on disabled state 2025-07-02 10:54:46 -04:00
Kent Keirsey
0de3967e7e remove stray file 2025-07-02 10:54:46 -04:00
Kent Keirsey
1335377fb1 Fixes 2025-07-02 10:54:46 -04:00
Cursor Agent
adbcc191d9 Add reference image enable/disable functionality
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 10:54:46 -04:00
Kent Keirsey
11fc7af1c8 fix 2025-07-02 10:47:01 -04:00
Cursor Agent
6f12fd22b9 Optimize image API invalidation tags and simplify cache invalidation logic
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 10:47:01 -04:00
Cursor Agent
324b6e2af4 Update LoRA select placeholder text for better clarity
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 10:36:45 -04:00
Mary Hipp Rogers
038010a1ca feat(ui): prompt expansion (#8140)
* initializing prompt expansion and putting response in prompt box working for all methods

* properly disable UI and show loading state on prompt box when there is a pending prompt expansion item

* misc wrapup: disable apploying prompt templates, dont block textarea resize handle

* update progress to differentiate between prompt expansion and non

* cleanup

* lint

* more cleanup

* add image to background of loading state

* add allowPromptExpansion for front-end gating

* updated readiness text for needing to accept or discard

* fix tsc

* lint

* lint

* refactor(ui): prompt expansion logic

* tidy(ui): remove unnecessary changes

* revert(ui): unused arg on useImageUploadButton

* feat(ui): simplify prompt expansion state

* set pending for dragndrop and context menu

* add readiness logic for generate tab

* missing translation

* update error handling for prompt expansion

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-07-02 10:26:48 -04:00
Cursor Agent
2dd1bc54c9 Set brush tool automatically when sending image to canvas
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 10:09:22 -04:00
Kent Keirsey
8b69842678 lint 2025-07-02 09:46:32 -04:00
Kent Keirsey
9821f7c4fc Remove Canvas Session 2025-07-02 09:46:32 -04:00
Cursor Agent
2290ff4ad6 Fix: Focus viewer panel when switching to workflow view mode
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 09:42:21 -04:00
psychedelicious
8d82ad6d0b fix(api): return HTTP errors from session queue handlers 2025-07-02 08:42:06 -04:00
Mary Hipp
8ed9f652e8 lint 2025-07-02 08:25:42 -04:00
Mary Hipp
ee8ed344bd add modelRelationships and aboutModal to disable-able features 2025-07-02 08:25:42 -04:00
Mary Hipp
6d16cfdbe2 missing import 2025-07-02 08:23:13 -04:00
Mary Hipp
3ef2872dda handle flux-kontext models 2025-07-02 08:23:13 -04:00
Cursor Agent
b52ba149b4 Update regional guidance empty state translation key
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 08:09:42 -04:00
Kent Keirsey
c6126c6875 Remove all references to New Sessions entirely. 2025-07-01 17:20:35 -04:00
psychedelicious
3f78ac9295 fix(ui): really do not load disabled tabs
Ensure disabled tabs are never mounted:
- Add didLoad flag to configSlice, default false
- Always merge in config - even if it is empty
- On first merge, set didLoad to true
- Until didLoad is true, mark _all_ tabs as disabled

This gets around an issue where tabs are all enabled for a brief moment
before the config is loaded.

A bit hacky but it works.
2025-07-01 10:52:28 -04:00
psychedelicious
79fea1ac40 chore: bump version to v6.0.0rc1 2025-07-02 00:14:13 +10:00
psychedelicious
6eade5781d feat(ui): remove mini metadata viewer 2025-07-01 23:37:31 +10:00
psychedelicious
3d8f865fb0 fix(ui): initial panel sizing 2025-07-01 23:37:31 +10:00
psychedelicious
dc9cd22d9d feat(ui): better naming for panel apis 2025-07-01 23:37:31 +10:00
psychedelicious
fe115ff8f9 fix(ui): models & queue tab styling 2025-07-01 23:37:31 +10:00
psychedelicious
1d35aad213 feat(ui): move more things over to panel registry 2025-07-01 23:37:31 +10:00
psychedelicious
195d6ce893 refactor(ui): implement global panel registry, replace context-based panel API 2025-07-01 23:37:31 +10:00
psychedelicious
f13ced7ed4 fix(ui): rebase conflicts 2025-07-01 23:37:31 +10:00
psychedelicious
735fc276e5 tidy(ui): clean up focus/layout container 2025-07-01 23:37:31 +10:00
psychedelicious
cd3caf8c30 fix(ui): delete image hotkey 2025-07-01 23:37:31 +10:00
psychedelicious
e9012280ab fix(ui): upscaling tab boards/gallery collapse 2025-07-01 23:37:31 +10:00
psychedelicious
fa72a97794 refactor(ui): even more better focus handling 2025-07-01 23:37:31 +10:00
psychedelicious
e817631ba3 refactor(ui): focus handling for new layout system (WIP) 2025-07-01 23:37:31 +10:00
psychedelicious
d0619c033f feat(ui): add edit button to current image buttons 2025-07-01 16:29:20 +10:00
psychedelicious
6f4850f34f tidy(ui): launchpad tab with icon cleanup 2025-07-01 15:37:06 +10:00
Kent Keirsey
072cd9dee7 Styling Fixes 2025-07-01 15:37:06 +10:00
Cursor Agent
19b6dc1c1f Add custom Launchpad tab with dynamic icon based on active tab
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 15:37:06 +10:00
Cursor Agent
7566d0d6c6 Enhance workflow mode toggle with panel navigation and focus
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 15:27:21 +10:00
psychedelicious
f123888b46 feat(ui): tidy workflows tab launchpad 2025-07-01 15:24:08 +10:00
psychedelicious
aeab7d0cab feat(ui): tidy upscaling tab launchpad 2025-07-01 15:24:08 +10:00
Kent Keirsey
3f1b2c39ab Model Guide link update 2025-07-01 15:24:08 +10:00
Kent Keirsey
72e3a4b4be Fixes & Updates 2025-07-01 15:24:08 +10:00
Kent Keirsey
58e0f80138 Lint 2025-07-01 15:24:08 +10:00
Kent Keirsey
8b8e29d22d Fixes & Styling updates 2025-07-01 15:24:08 +10:00
Kent Keirsey
90201be670 lint 2025-07-01 15:24:08 +10:00
Kent Keirsey
46a5619100 Update all text to translations 2025-07-01 15:24:08 +10:00
Kent Keirsey
d608a7469e Upscale Workflow Launchpad updates & translation updates 2025-07-01 15:24:08 +10:00
Cursor Agent
a7d413d372 Refactor Upscaling and Workflows Launchpad Panels with enhanced UI
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 15:24:08 +10:00
Cursor Agent
f5c9e68dbf Fix division by zero in multi-diffusion pipeline with creativity values
Co-authored-by: kent <kent@invoke.ai>

Revert unnecessary validation changes in multi-diffusion

Fix in python instead of graphbuilder

tidy(ui): remove extraneous comment
2025-07-01 15:00:02 +10:00
psychedelicious
1ded459f03 refactor(ui): clean up related models impl for picker 2025-07-01 14:52:26 +10:00
Kent Keirsey
d9024dc230 linting fixes 2025-07-01 14:52:26 +10:00
Kent Keirsey
40528692c3 Update icon 2025-07-01 14:52:26 +10:00
Kent Keirsey
f35b05be43 simplifies Modelpicker wrapper 2025-07-01 14:52:26 +10:00
Kent Keirsey
29e87fc615 lints 2025-07-01 14:52:26 +10:00
Kent Keirsey
ca26b2718e Small Changes 2025-07-01 14:52:26 +10:00
Cursor Agent
5fa6c0b413 Enhance model picker with related models and improved filtering
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 14:52:26 +10:00
psychedelicious
c37c8c50cd tidy(ui): clean up psd export 2025-07-01 14:12:14 +10:00
Kent Keirsey
f0a4de245d Moved size constants to a reasonable spot... 2025-07-01 14:12:14 +10:00
Kent Keirsey
5db62f8643 Fix Type refs 2025-07-01 14:12:14 +10:00
Kent Keirsey
e1c478f94c Size Updates 2025-07-01 14:12:14 +10:00
Kent Keirsey
11fe3b6332 Comments 2025-07-01 14:12:14 +10:00
Kent Keirsey
e4aae1a591 prettier 2025-07-01 14:12:14 +10:00
Kent Keirsey
4d83d1c56d Linting 2025-07-01 14:12:14 +10:00
Kent Keirsey
34def323e8 Restyle & locate 2025-07-01 14:12:14 +10:00
Kent Keirsey
854956316b Fix export layers 2025-07-01 14:12:14 +10:00
Cursor Agent
91afe7884a Add PSD export functionality for canvas layers
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 14:12:14 +10:00
psychedelicious
8417ee8a7b chore(ui): lint 2025-06-30 23:42:53 +10:00
psychedelicious
a035645ed3 refactor(ui): graph building respects selected tab 2025-06-30 23:42:53 +10:00
psychedelicious
e00ccba7d3 perf(ui): select only loading state for enqueueBatch mutation 2025-06-30 23:42:53 +10:00
psychedelicious
fb883d63aa refactor(ui): dedicated enqueue funcs for each tab 2025-06-30 23:42:53 +10:00
psychedelicious
b113c57fc4 refactor(ui): use redux-provided hooks for accessing store 2025-06-30 23:42:53 +10:00
psychedelicious
7636007349 fix(ui): useAppStore uses correct types 2025-06-30 23:42:53 +10:00
psychedelicious
fda86ae981 fix(app): incorrect node mappings when preparing collect nodes
The previous logic had a subtle Python bug related to scope and nested
generators.

Python generators are lazily evaluated - the expressions are stored and
only evaluated when needed (e.g. calling next() or list() on them)

The old logic used a variable `s`, which was continually overwritten as
the generator expressions were created. As a result, the final mappings
all use the _final_ value for `s`.

Following the consequences of this down the line, we find that collect
nodes can end up with multiple edges from exactly one of their ancestor
nodes, instead of one edge from each ancestor. Notably, it's only the
source _node_id_ that is affected - the source _fields_ have the correct
values.

So the invalid edges will point to a real node and a real field, but the
field exists on a different node.

---

This can result in a number of cryptic problems - including an error about
incompatible field types:

```
InvalidEdgeError: Field types are incompatible
(31758fd5-14a8-4de7-a840-b73ec1a1b94f.value ->
3459c793-41a2-4d82-9204-7df2d6d099ba.item)
```

Here are the conditions that lead to this error:
- The collect node has at least two incoming connections.
- The two incoming connections come from nodes of different types.
- The nodes both output a value of the same type, but the name of the
output field differs between them.

---

This commit uses non-generator logic to build up the mappings, avoiding
the issue entirely. As a bonus, it is much easier to read.
2025-06-30 23:39:28 +10:00
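The late-binding behaviour described above can be reproduced in isolation. A minimal sketch (not the actual graph-preparation code) showing why generators created in a loop all see the final value of the loop variable, and why building the mappings eagerly avoids the issue:

```python
# Generator expressions look up free variables like `s` when they are consumed,
# not when they are created, so every generator built in the loop ends up
# resolving `s` to its final value.
sources = ["node_a", "node_b", "node_c"]

lazy = []
for s in sources:
    lazy.append((s, x) for x in range(2))

print([list(g) for g in lazy])
# [[('node_c', 0), ('node_c', 1)], [('node_c', 0), ('node_c', 1)], [('node_c', 0), ('node_c', 1)]]

# Building the mappings eagerly (plain loops or comprehensions) captures the
# intended value of `s` for each source node.
eager = [[(s, x) for x in range(2)] for s in sources]
print(eager)
# [[('node_a', 0), ('node_a', 1)], [('node_b', 0), ('node_b', 1)], [('node_c', 0), ('node_c', 1)]]
```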
psychedelicious
c02be4bdf4 refactor(app): lean on pydantic to get field types in edge validation logic
Previously we used Python's own type introspection utilities to determine
input and output field types. We can use pydantic to get the field types
in a clearer, more direct way.

This improvement also exposed an awkward behaviour in this utility,
where it would return None when a field doesn't exist. I've added a
comment in the code describing the issue, but changing it would require
some significant changes and I don't want to risk breaking anything.
2025-06-30 23:39:28 +10:00
psychedelicious
ed7772d993 tests(app): add more tests for complex iterate/collect graph topologies 2025-06-30 23:39:28 +10:00
psychedelicious
baae998b5b tests(app): add failing test for collector edge case
squash

squash
2025-06-30 23:39:28 +10:00
DustyShoe
4077ffe595 Fixed a typo 2025-06-30 15:44:23 +10:00
psychedelicious
c1937b1379 chore: ruff 2025-06-30 12:56:51 +10:00
psychedelicious
5c66dfed8e fix(app): remove errant comment from prev impl 2025-06-30 12:56:51 +10:00
psychedelicious
126dcc96c0 feat(ui): clean up logging and comments in runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
cb9c7b4a28 feat(ui): simplify runGraph logic for error handling 2025-06-30 12:56:51 +10:00
psychedelicious
e8c4f49a14 feat(ui): add .wrap() method to WrappedError 2025-06-30 12:56:51 +10:00
psychedelicious
30fffae637 feat(ui): runGraph settlement callbacks can simply return or throw 2025-06-30 12:56:51 +10:00
psychedelicious
4558a292b6 tests(ui): update runGraph tests for separate options 2025-06-30 12:56:51 +10:00
psychedelicious
825d17441c feat(ui): separate options arg for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
9b16504af9 docs(ui): improved runGraph docstring 2025-06-30 12:56:51 +10:00
psychedelicious
46c92fadff feat(ui): use system logger for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
c0467b82ac tests(ui): update runGraph tests for new error state 2025-06-30 12:56:51 +10:00
psychedelicious
6dafa67286 feat(ui): improved logging for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
eb406aa07e feat(ui): mark runGraph error properties public readonly 2025-06-30 12:56:51 +10:00
psychedelicious
d9422ffebd tests(ui): add tests for enriched cancel/timeout errors 2025-06-30 12:56:51 +10:00
psychedelicious
d5c033be4d feat(ui): enrich cancel/timeout errors when queue item cancel fails 2025-06-30 12:56:51 +10:00
psychedelicious
4662cd6f15 fix(ui): await cancelation of queue item before returning 2025-06-30 12:56:51 +10:00
psychedelicious
a740a22613 feat(ui): runGraph uses settle for all promise handling, better comments 2025-06-30 12:56:51 +10:00
psychedelicious
bf4016b4bc feat(ui): add getNodes method to Graph 2025-06-30 12:56:51 +10:00
psychedelicious
6fa7c8c2ee feat(ui): better exception naming and docstrings in runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
ea40f582da tweak(ui): naming, code style 2025-06-30 12:56:51 +10:00
psychedelicious
01caf56251 feat(ui): clearer naming in WrappedError 2025-06-30 12:56:51 +10:00
psychedelicious
42d577e65a tests(ui): check for error instance instead of message 2025-06-30 12:56:51 +10:00
psychedelicious
38d80c9ce5 fix(ui): clear cleanupFunctions when finished calling them 2025-06-30 12:56:51 +10:00
psychedelicious
6acaa8abbf refactor(ui): use deferred promise as workaround to antipattern of async promise executor 2025-06-30 12:56:51 +10:00
psychedelicious
4b84e34599 refactor(ui): better race condition handling in runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
bbd21b1eb2 feat(ui): rename isSettled -> isFinished 2025-06-30 12:56:51 +10:00
psychedelicious
4fa83a6228 feat(ui): better error handling for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
051876dcff feat(ui): ensure promise always marked as settled, better comments 2025-06-30 12:56:51 +10:00
psychedelicious
8dc6d0b5ae feat(ui): use runGraph in canvas 2025-06-30 12:56:51 +10:00
psychedelicious
40e9624954 tests(ui): edge cases in runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
ae27c83dc4 feat(ui): log when cancelation fails 2025-06-30 12:56:51 +10:00
psychedelicious
161059551b fix(ui): handle errors during cleanup 2025-06-30 12:56:51 +10:00
psychedelicious
c196f8a5d5 tests(ui): add tests for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
2c6d22664e feat(ui): use DI to make runGraph testable 2025-06-30 12:56:51 +10:00
psychedelicious
b9ce5389ef fix(ui): clean up signal 2025-06-30 12:56:51 +10:00
psychedelicious
d1cbf56695 feat(ui): iterate on runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
e379ac12c3 feat(ui): abstraction to make a graph await-able 2025-06-30 12:56:51 +10:00
psychedelicious
aa10373292 feat(ui): loosen typings for Result 2025-06-30 12:56:51 +10:00
psychedelicious
780f3692a0 chore(ui): typegen 2025-06-30 12:56:51 +10:00
psychedelicious
3604dcfdd1 feat(api): return list of enqueued item ids when enqueuing 2025-06-30 12:56:51 +10:00
Jonathan
2b1cffde5e typegen 2025-06-30 11:28:02 +10:00
Jonathan
83d642ed15 Update flux_denoise.py
Fixed version to 4.0.0
2025-06-30 11:28:02 +10:00
Jonathan
455c73235e Update flux_denoise.py
Updated version, removed WithBoard and WithMetadata
2025-06-30 11:28:02 +10:00
psychedelicious
8efef8da41 feat(ui): workflows styling tweaks 2025-06-30 11:17:29 +10:00
psychedelicious
060a9e57b9 fix(ui): prevent NaN from getting into konva internals 2025-06-30 10:43:11 +10:00
skunkworxdark
099d75ca1e use "\u2581" instead of the character itself for clarity 2025-06-30 10:40:31 +10:00
skunkworxdark
bbb5d68146 Update flux_text_encoder.py
Added tokenizer logging to flux
2025-06-30 10:40:31 +10:00
psychedelicious
9066dc1839 tidy(nodes): remove extraneous comments & add useful ones 2025-06-27 18:27:46 +10:00
psychedelicious
075345bffd feat(app): add flux kontext dev to starter models 2025-06-27 18:27:46 +10:00
psychedelicious
74d1239c87 chore(ui): typegen 2025-06-27 18:27:46 +10:00
Kent Keirsey
51e1c56636 ruff 2025-06-27 18:27:46 +10:00
Kent Keirsey
ca1df60e54 Explain the Magic 2025-06-27 18:27:46 +10:00
Cursor Agent
7549c1250d Add FLUX Kontext conditioning support for reference images
Co-authored-by: kent <kent@invoke.ai>

Fix Kontext sequence length handling in Flux denoise invocation

Co-authored-by: kent <kent@invoke.ai>

Fix Kontext step callback to handle combined token sequences

Co-authored-by: kent <kent@invoke.ai>

fix ruff

Fix Flux Kontext
2025-06-27 18:27:46 +10:00
psychedelicious
df8751b5a1 fix(ui): remove extraneous rect in stagingareamodule 2025-06-27 15:45:53 +10:00
psychedelicious
651b80b997 fix(ui): remove extraneous syncPlaceholderSize method and calls 2025-06-27 15:45:53 +10:00
psychedelicious
5d236ae4e7 fix(ui): canvas staging waiting for image placeholder sizing and layout 2025-06-27 15:45:53 +10:00
psychedelicious
e5dc606f5e fix(ui): get accurate theme tokens 2025-06-27 15:45:53 +10:00
Kent Keirsey
dc6b8e13bd prettier 2025-06-27 15:45:53 +10:00
Cursor Agent
c1b34e1f11 Standardize UI spacing and constants across canvas and image components
Co-authored-by: kent <kent@invoke.ai>
2025-06-27 15:45:53 +10:00
Cursor Agent
89f1684072 Improve placeholder styling with badge and refined text positioning
Co-authored-by: kent <kent@invoke.ai>
2025-06-27 15:45:53 +10:00
Kent Keirsey
14fbee17a3 Rule of 3rds Composition Guide (#8130)
* Add Rule of 4 composition guide to canvas settings and rendering

Co-authored-by: kent <kent@invoke.ai>

* Rename Rule of 4 Guide to Rule of Thirds in canvas composition guide

Co-authored-by: kent <kent@invoke.ai>

* Updates to comp guide and naming

* Fix reference

* Update translation keys and organize settings.

* revert to previous canvas manager for conflict

* Re-add composition guide.

* Fix lint

* prettier

* feat(ui): improve markup in canvas settings popover

* feat(ui): use brand colors for canvas rule of thirds guide

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-06-27 15:05:34 +10:00
psychedelicious
5dbc32e06e feat(ui): minor restyle of style preset list 2025-06-27 14:40:35 +10:00
psychedelicious
23baf61e51 fix(ui): remove extraneous slice migration for style presets 2025-06-27 14:40:35 +10:00
Kent Keirsey
5e55f6074b prettier 2025-06-27 14:40:35 +10:00
Kent Keirsey
f7c555e501 Change to Toggle Tooltip 2025-06-27 14:40:35 +10:00
Cursor Agent
6aa605e811 Add toggle for showing/hiding style preset prompt previews
Co-authored-by: kent <kent@invoke.ai>
2025-06-27 14:40:35 +10:00
psychedelicious
f51014e108 feat(ui): make launchpad button its own component 2025-06-27 14:37:30 +10:00
psychedelicious
9862ba9210 feat(ui): improved starter model buttons & tooltips 2025-06-27 14:37:30 +10:00
psychedelicious
920aea08cc tidy(ui): remove unused translation strings 2025-06-27 14:37:30 +10:00
psychedelicious
39e584297e feat(ui): fix missing translations 2025-06-27 14:37:30 +10:00
psychedelicious
62a14bb935 feat(ui): use enriched starter model metadata 2025-06-27 14:37:30 +10:00
psychedelicious
d7ae2cdf75 chore(ui): typegen 2025-06-27 14:37:30 +10:00
psychedelicious
6172c859ac feat(api): enrich starter model bundle metadata 2025-06-27 14:37:30 +10:00
psychedelicious
b26fb1f617 feat(ui): simplify markup for install models launchpad form 2025-06-27 14:37:30 +10:00
psychedelicious
05167dfd7a feat(ui): use existing design language for install model bundle buttons 2025-06-27 14:37:30 +10:00
psychedelicious
c090ea7387 feat(ui): use existing design language for install model launchpad buttons 2025-06-27 14:37:30 +10:00
psychedelicious
7ba6c67049 feat(ui): named install models tabs 2025-06-27 14:37:30 +10:00
psychedelicious
3de186061d chore(ui): lint 2025-06-27 14:37:30 +10:00
Kent Keirsey
a716381733 Model Launchpad prettier 2025-06-27 14:37:30 +10:00
Kent Keirsey
fb5df06835 Updating to include translations and import fixes 2025-06-27 14:37:30 +10:00
Kent Keirsey
33c597c224 fix lint 2025-06-27 14:37:30 +10:00
Kent Keirsey
19d882d038 Address comments 2025-06-27 14:37:30 +10:00
Kent Keirsey
ee4bc49bd4 Prettier. 2025-06-27 14:37:30 +10:00
Kent Keirsey
188cf37f48 fix lint 2025-06-27 14:37:30 +10:00
Kent Keirsey
15a0a7134c fix circ dependency 2025-06-27 14:37:30 +10:00
Kent Keirsey
22cea0de8b Remove scrap 2025-06-27 14:37:30 +10:00
Kent Keirsey
cd21816d12 Model Launchpad 2025-06-27 14:37:30 +10:00
psychedelicious
605b912ba4 fix(ui): remove noop hook 2025-06-27 11:37:47 +10:00
psychedelicious
52e31112f9 chore(ui): lint 2025-06-27 11:37:47 +10:00
Kent Keirsey
a4c9346cd7 lint 2025-06-27 11:37:47 +10:00
Kent Keirsey
a1647e4c6e Address comments 2025-06-27 11:37:47 +10:00
Kent Keirsey
8c9ca088a7 update tooltip 2025-06-27 11:37:47 +10:00
Cursor Agent
7a7a2e147c Add toggle for non-raster layers with hotkey and UI button 2025-06-27 11:37:47 +10:00
psychedelicious
adf4cc750a fix(ui): Fix LoRA picker to default to current base model architecture (#8135)
Enhance LoRA picker to default filter by current base model architecture

## Summary
Fixes new LoRA picker to auto select the architecture filter for the
current model group

## Related Issues / Discussions
N/A

## QA Instructions

Open LoRA menu with any model group selected. The right models should be
filtered.

## Merge Plan
Merge when ready.

## Checklist

- [X] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-06-27 11:21:39 +10:00
psychedelicious
9f1ea9d1c7 fix(ui): use existing GroupStatusMap type 2025-06-27 11:19:24 +10:00
Cursor Agent
571d286506 Enhance LoRA picker to default to current base model architecture
Co-authored-by: kent <kent@invoke.ai>

Enhance LoRA picker to default filter by current base model architecture

Co-authored-by: kent <kent@invoke.ai>
2025-06-26 20:43:43 -04:00
Mary Hipp
1320a2c5f8 add option to override text for no options available 2025-06-26 18:09:57 -04:00
Mary Hipp
26a9b3131d convert LoRA picker to use new model picker component 2025-06-26 18:09:57 -04:00
psychedelicious
d48140b35d fix(ui): regional guidance ref image not selecting 2025-06-26 10:05:25 -04:00
422 changed files with 22829 additions and 16166 deletions

View File

@@ -3,15 +3,15 @@ description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: setup node 18
- name: setup node 20
uses: actions/setup-node@v4
with:
node-version: '18'
node-version: '20'
- name: setup pnpm
uses: pnpm/action-setup@v4
with:
version: 8.15.6
version: 10
run_install: false
- name: get pnpm store directory

View File

@@ -35,7 +35,7 @@ More detail on system requirements can be found [here](./requirements.md).
## Step 2: Download
Download the most launcher for your operating system:
Download the most recent launcher for your operating system:
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)

View File

@@ -72,7 +72,7 @@ async def upload_image(
resize_to: Optional[str] = Body(
default=None,
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
example='"[1024,1024]"',
examples=['"[1024,1024]"'],
),
metadata: Optional[str] = Body(
default=None,

View File

@@ -41,6 +41,7 @@ from invokeai.backend.model_manager.starter_models import (
STARTER_BUNDLES,
STARTER_MODELS,
StarterModel,
StarterModelBundle,
StarterModelWithoutDependencies,
)
@@ -291,7 +292,7 @@ async def get_hugging_face_models(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
) -> AnyModelConfig:
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
@@ -449,7 +450,7 @@ async def install_model(
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
example={"name": "string", "description": "string"},
examples=[{"name": "string", "description": "string"}],
),
) -> ModelInstallJob:
"""Install a model using a string identifier.
@@ -799,7 +800,7 @@ async def convert_model(
class StarterModelResponse(BaseModel):
starter_models: list[StarterModel]
starter_bundles: dict[str, list[StarterModel]]
starter_bundles: dict[str, StarterModelBundle]
def get_is_installed(
@@ -833,7 +834,7 @@ async def get_starter_models() -> StarterModelResponse:
model.dependencies = missing_deps
for bundle in starter_bundles.values():
for model in bundle:
for model in bundle.models:
model.is_installed = get_is_installed(model, installed_models)
# Remove already-installed dependencies
missing_deps: list[StarterModelWithoutDependencies] = []

View File

@@ -1,6 +1,6 @@
from typing import Optional
from fastapi import Body, Path, Query
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
@@ -22,6 +22,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemNotFoundError,
SessionQueueStatus,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -59,10 +60,12 @@ async def enqueue_batch(
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
try:
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
@session_queue_router.get(
@@ -82,14 +85,17 @@ async def list_queue_items(
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
try:
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
@session_queue_router.get(
@@ -104,11 +110,13 @@ async def list_all_queue_items(
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
try:
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
@session_queue_router.put(
@@ -120,7 +128,10 @@ async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Resumes session processor"""
return ApiDependencies.invoker.services.session_processor.resume()
try:
return ApiDependencies.invoker.services.session_processor.resume()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while resuming queue: {e}")
@session_queue_router.put(
@@ -132,7 +143,10 @@ async def Pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Pauses session processor"""
return ApiDependencies.invoker.services.session_processor.pause()
try:
return ApiDependencies.invoker.services.session_processor.pause()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pausing queue: {e}")
@session_queue_router.put(
@@ -144,7 +158,10 @@ async def cancel_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> CancelAllExceptCurrentResult:
"""Immediately cancels all queue items except in-processing items"""
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
try:
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
@session_queue_router.put(
@@ -156,7 +173,10 @@ async def delete_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> DeleteAllExceptCurrentResult:
"""Immediately deletes all queue items except in-processing items"""
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
try:
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
@session_queue_router.put(
@@ -169,7 +189,12 @@ async def cancel_by_batch_ids(
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
queue_id=queue_id, batch_ids=batch_ids
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
@session_queue_router.put(
@@ -182,9 +207,12 @@ async def cancel_by_destination(
destination: str = Query(description="The destination to cancel all queue items for"),
) -> CancelByDestinationResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
@session_queue_router.put(
@@ -197,7 +225,10 @@ async def retry_items_by_id(
item_ids: list[int] = Body(description="The queue item ids to retry"),
) -> RetryItemsResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
try:
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
@session_queue_router.put(
@@ -211,11 +242,14 @@ async def clear(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
@session_queue_router.put(
@@ -229,7 +263,10 @@ async def prune(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
try:
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
@session_queue_router.get(
@@ -243,7 +280,10 @@ async def get_current_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
try:
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
@session_queue_router.get(
@@ -257,7 +297,10 @@ async def get_next_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
try:
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
@session_queue_router.get(
@@ -271,9 +314,12 @@ async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting queue status: {e}")
@session_queue_router.get(
@@ -288,7 +334,10 @@ async def get_batch_status(
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of the session queue"""
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
try:
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
@session_queue_router.get(
@@ -304,7 +353,12 @@ async def get_queue_item(
item_id: int = Path(description="The queue item to get"),
) -> SessionQueueItem:
"""Gets a queue item"""
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
try:
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching queue item: {e}")
@session_queue_router.delete(
@@ -316,7 +370,10 @@ async def delete_queue_item(
item_id: int = Path(description="The queue item to delete"),
) -> None:
"""Deletes a queue item"""
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
try:
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
@session_queue_router.put(
@@ -331,8 +388,12 @@ async def cancel_queue_item(
item_id: int = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Deletes a queue item"""
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
try:
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
@session_queue_router.get(
@@ -345,9 +406,12 @@ async def counts_by_destination(
destination: str = Query(description="The destination to query"),
) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination"""
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
try:
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
@session_queue_router.delete(
@@ -360,6 +424,9 @@ async def delete_by_destination(
destination: str = Path(description="The destination to query"),
) -> DeleteByDestinationResult:
"""Deletes all items with the given destination"""
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)
try:
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")

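The changes above apply one uniform pattern to every endpoint in this router: call the session-queue service inside a try block, translate known domain errors (e.g. SessionQueueItemNotFoundError) into a 404, and translate anything else into a 500 whose detail names the operation that failed. A minimal sketch of that pattern, using a hypothetical wrap_queue_call helper that is not part of the diff, assuming FastAPI's HTTPException:

from typing import Any, Callable

from fastapi import HTTPException


def wrap_queue_call(description: str, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call a session-queue service method, mapping unexpected failures to HTTP 500."""
    try:
        return fn(*args, **kwargs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error while {description}: {e}")

The endpoints in the diff inline this pattern rather than factoring out a helper, which keeps the endpoint-specific status codes (such as the 404 cases) visible at each call site.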
View File

@@ -215,6 +215,7 @@ class FieldDescriptions:
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"
flux_fill_conditioning = "FLUX Fill conditioning tensor"
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"
class ImageField(BaseModel):
@@ -291,6 +292,12 @@ class FluxFillConditioningField(BaseModel):
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
class FluxKontextConditioningField(BaseModel):
"""A conditioning field for FLUX Kontext (reference image)."""
image: ImageField = Field(description="The Kontext reference image.")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -16,13 +16,12 @@ from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
FluxFillConditioningField,
FluxKontextConditioningField,
FluxReduxConditioningField,
ImageField,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
@@ -34,6 +33,7 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXCo
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
@@ -63,9 +63,9 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.3.0",
version="4.0.0",
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
class FluxDenoiseInvocation(BaseInvocation):
"""Run denoising process with a FLUX transformer model."""
# If latents is provided, this means we are doing image-to-image.
@@ -145,11 +145,20 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.vae,
input=Input.Connection,
)
# This node accepts images for features like FLUX Fill, ControlNet, and Kontext, but needs to operate on them in
# latent space. We'll run the VAE to encode them in this node instead of requiring the user to run the VAE in
# upstream nodes.
ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
)
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
default=None,
description="FLUX Kontext conditioning (reference image).",
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
@@ -376,6 +385,27 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
)
kontext_extension = None
if self.kontext_conditioning is not None:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning,
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
# Prepare Kontext conditioning if provided
img_cond_seq = None
img_cond_seq_ids = None
if kontext_extension is not None:
# Ensure batch sizes match
kontext_extension.ensure_batch_size(x.shape[0])
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
x = denoise(
model=transformer,
img=x,
@@ -391,6 +421,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
img_cond_seq=img_cond_seq,
img_cond_seq_ids=img_cond_seq_ids,
)
x = unpack(x.float(), self.height, self.width)
@@ -865,7 +897,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
# The denoise function now handles Kontext conditioning correctly,
# so we don't need to slice the latents here
latents = state.latents.float()
state.latents = unpack(latents, self.height, self.width).squeeze()
context.util.flux_step_callback(state)
return step_callback

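The denoise node only touches three members of the extension: kontext_latents, kontext_ids, and ensure_batch_size(). The extension itself lives in invokeai.backend.flux.extensions.kontext_extension and is not part of this diff, so the following is only a rough sketch of that interface; the tensor shapes and the repeat-along-batch behavior are assumptions, not the actual implementation:

import torch


class KontextExtensionSketch:
    """Sketch of the interface consumed by FluxDenoiseInvocation above (assumed, not actual)."""

    def __init__(self, kontext_latents: torch.Tensor, kontext_ids: torch.Tensor) -> None:
        # Packed latent sequence of the VAE-encoded reference image, assumed shape (1, seq_len, dim).
        self.kontext_latents = kontext_latents
        # Positional ids for that sequence, assumed shape (1, seq_len, 3).
        self.kontext_ids = kontext_ids

    def ensure_batch_size(self, batch_size: int) -> None:
        # Repeat the conditioning along the batch dimension so it matches the noise tensor x.
        if self.kontext_latents.shape[0] != batch_size:
            self.kontext_latents = self.kontext_latents.expand(batch_size, -1, -1)
            self.kontext_ids = self.kontext_ids.expand(batch_size, -1, -1)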
View File

@@ -0,0 +1,40 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxKontextConditioningField,
InputField,
OutputField,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_kontext_output")
class FluxKontextOutput(BaseInvocationOutput):
"""The conditioning output of a FLUX Kontext invocation."""
kontext_cond: FluxKontextConditioningField = OutputField(
description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
)
@invocation(
"flux_kontext",
title="Kontext Conditioning - FLUX",
tags=["conditioning", "kontext", "flux"],
category="conditioning",
version="1.0.0",
)
class FluxKontextInvocation(BaseInvocation):
"""Prepares a reference image for FLUX Kontext conditioning."""
image: ImageField = InputField(description="The Kontext reference image.")
def invoke(self, context: InvocationContext) -> FluxKontextOutput:
"""Packages the provided image into a Kontext conditioning field."""
return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))

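Putting the two new pieces together: FluxKontextInvocation wraps an ImageField in a FluxKontextConditioningField, and that output is what connects (via Input.Connection) to the kontext_conditioning input added to FLUX Denoise above. An illustrative snippet, not a runnable workflow graph, with a placeholder image name:

from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.primitives import ImageField

ref = ImageField(image_name="reference.png")     # "reference.png" is hypothetical
cond = FluxKontextConditioningField(image=ref)   # what FluxKontextInvocation.invoke() returns
# In a workflow, `cond` feeds FluxDenoiseInvocation.kontext_conditioning.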
View File

@@ -1,5 +1,5 @@
from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple
from typing import Iterator, Literal, Optional, Tuple, Union
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
@@ -111,6 +111,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
if context.config.get().log_tokenization:
self._log_t5_tokenization(context, t5_tokenizer)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
@@ -151,6 +154,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
if context.config.get().log_tokenization:
self._log_clip_tokenization(context, clip_tokenizer)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
@@ -170,3 +176,88 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _log_t5_tokenization(
self,
context: InvocationContext,
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
) -> None:
"""Logs the tokenization of a prompt for a T5-based model like FLUX."""
# Tokenize the prompt using the same parameters as the model's text encoder.
# T5 tokenizers add an EOS token (</s>) and then pad to max_length.
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=self.t5_max_seq_len,
truncation=True,
add_special_tokens=True, # This is important for T5 to add the EOS token.
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The T5 tokenizer uses a space-like character '▁' (U+2581) to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("\u2581", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for token in tokens:
if token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
used_tokens += 1
elif token == tokenizer.pad_token:
# tokenized_str += f"\x1b[0;34m{token}\x1b[0m" # Blue for PAD
continue
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
context.logger.info(f"{tokenized_str}\x1b[0m")
def _log_clip_tokenization(
self,
context: InvocationContext,
tokenizer: CLIPTokenizer,
) -> None:
"""Logs the tokenization of a prompt for a CLIP-based model."""
max_length = tokenizer.model_max_length
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
attention_mask = tokenized_output.attention_mask[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The CLIP tokenizer uses '</w>' to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("</w>", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for i, token in enumerate(tokens):
if attention_mask[i] == 0:
# Do not log padding tokens.
continue
if token == tokenizer.bos_token:
tokenized_str += f"\x1b[0;32m{token}\x1b[0m" # Green for BOS
elif token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
context.logger.info(f"{tokenized_str}\x1b[0m")

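Both loggers rely on the same ANSI trick: cycle the foreground color per kept token ((used_tokens % 6) + 1 maps to escape codes 31-36) so adjacent tokens are visually distinct in the terminal. A minimal standalone sketch of just that coloring step:

def colorize_tokens(tokens: list[str]) -> str:
    """Join tokens into one string, cycling ANSI foreground colors 31-36."""
    out = []
    for i, token in enumerate(tokens):
        color = (i % 6) + 1  # 1..6 -> "\x1b[0;31m" .. "\x1b[0;36m"
        out.append(f"\x1b[0;3{color}m{token}\x1b[0m")
    return "".join(out)


print(colorize_tokens(["a ", "photo ", "of ", "an ", "astronaut"]))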
View File

@@ -14,15 +14,14 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
@@ -31,17 +30,12 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_image_from_board(
self,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM board_images
@@ -49,10 +43,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_images_for_board(
self,
@@ -60,27 +50,26 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
@@ -90,56 +79,55 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
params: list[str | bool] = []
with self._db.transaction() as cursor:
params: list[str | bool] = []
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Put a ring on it
stmt += ";"
# Put a ring on it
stmt += ";"
# Execute the query
cursor = self._conn.cursor()
cursor.execute(stmt, params)
cursor.execute(stmt, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
@@ -147,31 +135,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> Optional[str]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count

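This file, and the board, image, and model record stores that follow, replace manual cursor/commit/rollback handling with a SqliteDatabase.transaction() context manager. That context manager is not shown in this section; a minimal sketch, under the assumption that it yields a cursor, commits when the block succeeds, and rolls back and re-raises on any exception:

import sqlite3
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def transaction(conn: sqlite3.Connection) -> Iterator[sqlite3.Cursor]:
    """Yield a cursor; commit if the block succeeds, roll back and re-raise if it fails."""
    cursor = conn.cursor()
    try:
        yield cursor
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        cursor.close()

The real SqliteDatabase.transaction() presumably also guards against concurrent access from multiple threads, which this sketch omits.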
View File

@@ -20,61 +20,57 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def delete(self, board_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
except Exception as e:
raise BoardRecordDeleteException from e
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = uuid_string()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
with self._db.transaction() as cursor:
try:
board_id = uuid_string()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -84,45 +80,43 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
with self._db.transaction() as cursor:
try:
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
def get_many(
@@ -133,78 +127,77 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
cursor = self._conn.cursor()
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
with self._db.transaction() as cursor:
# Build base query
base_query = """
SELECT *
FROM boards
WHERE archived = 0;
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Execute count query
cursor.execute(count_query)
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
count = cast(int, cursor.fetchone()[0])
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
cursor.execute(count_query)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query)
cursor.execute(final_query)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards

View File

@@ -8,6 +8,7 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from shutil import disk_usage
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
import requests
@@ -335,6 +336,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
assert job.download_path
free_space = disk_usage(job.download_path.parent).free
GB = 2**30
self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
if free_space < job.total_bytes:
raise RuntimeError(
f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
)
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():

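The download-queue hunk above refuses to start writing when the destination filesystem cannot hold job.total_bytes. The check in isolation, mirroring the diff's use of shutil.disk_usage and a 2**30-byte GB constant:

from pathlib import Path
from shutil import disk_usage

GB = 2**30  # bytes per binary gigabyte, matching the diff


def check_free_space(download_path: Path, total_bytes: int) -> None:
    """Raise if the parent filesystem lacks room for the pending download."""
    free_space = disk_usage(download_path.parent).free
    if free_space < total_bytes:
        raise RuntimeError(
            f"Free disk space {free_space / GB:.2f} GB is not enough for download of {total_bytes / GB:.2f} GB."
        )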
View File

@@ -24,22 +24,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def get(self, image_name: str) -> ImageRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -47,17 +47,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -65,64 +68,60 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
as_dict = dict(result)
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> None:
try:
cursor = self._conn.cursor()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
with self._db.transaction() as cursor:
try:
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `is_intermediate` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `starred` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
# Change the image's `starred` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
except sqlite3.Error as e:
raise ImageRecordSaveException from e
def get_many(
self,
@@ -136,170 +135,162 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
cursor = self._conn.cursor()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
with self._db.transaction() as cursor:
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_params.append(is_intermediate)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
query_params.append(is_intermediate)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
try:
placeholders = ",".join("?" for _ in image_names)
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
return count
def delete_intermediates(self) -> list[str]:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
self._conn.commit()
return image_names
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
return image_names
def save(
self,
@@ -315,73 +306,71 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str] = None,
metadata: Optional[str] = None,
) -> datetime:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
self._conn.commit()
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(cursor.fetchone()[0])
created_at = datetime.fromisoformat(cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
except sqlite3.Error as e:
raise ImageRecordSaveException from e
return created_at
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if result is None:
return None
@@ -398,85 +387,84 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [row[0] for row in result]
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))

View File

@@ -78,11 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._logger = logger
@property
def db(self) -> SqliteDatabase:
"""Return the underlying database."""
return self._db
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@@ -93,38 +88,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
except sqlite3.IntegrityError as e:
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
raise e
return self.get_model(config.key)
@@ -136,8 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
try:
cursor = self._db.conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM models
@@ -147,22 +136,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
with self._db.transaction() as cursor:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
json_serialized = record.model_dump_json()
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
@@ -174,10 +158,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@@ -189,30 +169,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -224,15 +204,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -255,43 +235,42 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -313,69 +292,68 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
cursor = self._db.conn.cursor()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)

View File

@@ -1,5 +1,3 @@
import sqlite3
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
@@ -9,58 +7,49 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_related_model_keys(self, model_key: str) -> list[str]:
cursor = self._conn.cursor()
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
return [row[0] for row in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
result = [row[0] for row in cursor.fetchall()]
return result
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
cursor = self._conn.cursor()
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
return [row[0] for row in cursor.fetchall()]
with self._db.transaction() as cursor:
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
result = [row[0] for row in cursor.fetchall()]
return result
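A brief, hypothetical usage sketch of this relationship storage (the database setup and model keys below are illustrative, not part of the change):

    # Illustrative only: relating two models and reading the relationship back.
    # Assumes the model_relationships table already exists (i.e. migrations have run).
    from logging import getLogger

    db = SqliteDatabase(db_path=None, logger=getLogger(__name__))
    storage = SqliteModelRelationshipRecordStorage(db)
    storage.add_model_relationship("model-a", "model-b")
    related = storage.get_related_model_keys("model-a")  # -> ["model-b"]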

View File

@@ -332,6 +332,7 @@ class EnqueueBatchResult(BaseModel):
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
item_ids: list[int] = Field(description="The IDs of the queue items that were enqueued")
class RetryItemsResult(BaseModel):

View File

@@ -50,15 +50,14 @@ class SqliteSessionQueue(SessionQueueBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -66,99 +65,104 @@ class SqliteSessionQueue(SessionQueueBase):
WHERE status = 'in_progress';
"""
)
except Exception:
self._conn.rollback()
raise
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(int, cursor.fetchone()[0])
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(Union[int, None], cursor.fetchone()[0]) or 0
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
return priority
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
with self._conn:
cursor = self._conn.cursor()
cursor.executemany(
"""--sql
with self._db.transaction() as cursor:
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
except Exception:
raise
values_to_insert,
)
cursor.execute(
"""--sql
SELECT item_id
FROM session_queue
WHERE batch_id = ?
ORDER BY item_id DESC;
""",
(batch.batch_id,),
)
item_ids = [row[0] for row in cursor.fetchall()]
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
enqueued=enqueued_count,
batch=batch,
priority=priority,
item_ids=item_ids,
)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
return enqueue_result
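A hypothetical consumer sketch for the new item_ids field (the wrapper function is illustrative, not from the diff):

    # Illustrative only: callers can now track exactly which queue items an enqueue created.
    async def enqueue_and_report(queue: SqliteSessionQueue, queue_id: str, batch: Batch) -> list[int]:
        result = await queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=False)
        # item_ids come back newest-first, matching the ORDER BY item_id DESC query above.
        return result.item_ids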
def dequeue(self) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -166,40 +170,40 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -212,8 +216,7 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
@@ -221,12 +224,15 @@ class SqliteSessionQueue(SessionQueueBase):
(item_id,),
)
row = cursor.fetchone()
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -235,10 +241,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -246,35 +249,34 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
@@ -292,24 +294,19 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id = ?
AND (
queue_id = ?
AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
)
)
"""
cursor.execute(
f"""--sql
@@ -328,10 +325,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
@@ -344,8 +337,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE
@@ -354,10 +346,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(item_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
@@ -380,8 +368,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -391,6 +378,8 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id] + batch_ids
cursor.execute(
@@ -410,17 +399,14 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -429,6 +415,8 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = (queue_id, destination)
cursor.execute(
@@ -448,17 +436,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
@@ -484,15 +467,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -515,15 +493,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -531,6 +504,8 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id]
cursor.execute(
@@ -550,21 +525,13 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -587,30 +554,25 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
@@ -623,10 +585,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get_queue_item(item_id)
def list_queue_items(
@@ -638,42 +596,42 @@ class SqliteSessionQueue(SessionQueueBase):
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params.append(destination)
params: list[Union[str, int]] = [queue_id]
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.extend([priority, priority, item_id])
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
@@ -688,46 +646,46 @@ class SqliteSessionQueue(SessionQueueBase):
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
cursor_ = self._conn.cursor()
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if destination is not None:
query += """---sql
AND destination = ?
with self._db.transaction() as cursor:
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params.append(destination)
params: list[Union[str, int]] = [queue_id]
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor.execute(query, params)
results = cast(list[sqlite3.Row], cursor.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] for row in counts_result)
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueStatus(
queue_id=queue_id,
@@ -743,20 +701,20 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in result)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
@@ -775,20 +733,20 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in counts_result)
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueCountsByDestination(
@@ -804,8 +762,7 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []
@@ -856,10 +813,6 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,

View File

@@ -2,7 +2,7 @@
import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
import networkx as nx
from pydantic import (
@@ -58,17 +58,32 @@ class Edge(BaseModel):
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want to risk breaking anything.
try:
invocation_class = type(node)
invocation_output_class = invocation_class.get_output_annotation()
field_info = invocation_output_class.model_fields.get(field)
assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}"
output_field_type = field_info.annotation
return output_field_type
except Exception:
return None
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
return node_input_field
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined on the invocation, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want to risk breaking anything.
try:
invocation_class = type(node)
field_info = invocation_class.model_fields.get(field)
assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}"
input_field_type = field_info.annotation
return input_field_type
except Exception:
return None
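A hypothetical sketch of how the two helpers might be combined to check an edge (the wrapper function is illustrative, not part of the diff):

    # Illustrative only: comparing an edge's output and input field types.
    def edge_types_match(source_node: BaseInvocation, source_field: str,
                         dest_node: BaseInvocation, dest_field: str) -> bool:
        output_type = get_output_field_type(source_node, source_field)
        input_type = get_input_field_type(dest_node, dest_field)
        # Both helpers return None instead of raising when a field is unknown.
        if output_type is None or input_type is None:
            return False
        return output_type == input_type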
def is_union_subtype(t1, t2):
@@ -992,10 +1007,11 @@ class GraphExecutionState(BaseModel):
new_node_ids = []
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list(
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
)
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
all_iteration_mappings = []
for source_node_id in next_node_parents:
prepared_nodes = self.source_prepared_mapping[source_node_id]
all_iteration_mappings.extend([(source_node_id, p) for p in prepared_nodes])
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
if create_results is not None:
new_node_ids.extend(create_results)

View File

@@ -1,4 +1,7 @@
import sqlite3
import threading
from collections.abc import Generator
from contextlib import contextmanager
from logging import Logger
from pathlib import Path
@@ -26,46 +29,65 @@ class SqliteDatabase:
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self.logger = logger
self.db_path = db_path
self.verbose = verbose
self._logger = logger
self._db_path = db_path
self._verbose = verbose
self._lock = threading.RLock()
if not self.db_path:
if not self._db_path:
logger.info("Initializing in-memory database")
else:
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Initializing database at {self.db_path}")
self._db_path.parent.mkdir(parents=True, exist_ok=True)
self._logger.info(f"Initializing database at {self._db_path}")
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.conn.row_factory = sqlite3.Row
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
self._conn.row_factory = sqlite3.Row
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
if self._verbose:
self._conn.set_trace_callback(self._logger.debug)
# Enable foreign key constraints
self.conn.execute("PRAGMA foreign_keys = ON;")
self._conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
self.conn.execute("PRAGMA journal_mode = WAL;")
self._conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
if not self.db_path:
if not self._db_path:
return
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
with self._conn as conn:
initial_db_size = Path(self._db_path).stat().st_size
conn.execute("VACUUM;")
conn.commit()
final_db_size = Path(self._db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
self._logger.error(f"Error cleaning database: {e}")
raise
@contextmanager
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""
Thread-safe context manager for DB work.
Acquires the RLock, yields a Cursor, then commits or rolls back.
"""
with self._lock:
cursor = self._conn.cursor()
try:
yield cursor
self._conn.commit()
except:
self._conn.rollback()
raise
finally:
cursor.close()

View File

@@ -32,7 +32,7 @@ class SqliteMigrator:
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
self._logger = db.logger
self._logger = db._logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
@@ -45,7 +45,7 @@ class SqliteMigrator:
"""Migrates the database to the latest version."""
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
cursor = self._db._conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
@@ -59,13 +59,13 @@ class SqliteMigrator:
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
if self._db._db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
self._db._conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
@@ -81,7 +81,7 @@ class SqliteMigrator:
try:
# Using sqlite3.Connection as a context manager commits the transaction on exit, or rolls it back if an
# exception is raised.
with self._db.conn as conn:
with self._db._conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(

View File

@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -25,24 +25,23 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
@@ -60,16 +59,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
@@ -90,16 +84,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
@@ -122,15 +111,10 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from style_presets
@@ -138,51 +122,41 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
main_query = """
SELECT
*
FROM style_presets
"""
with self._db.transaction() as cursor:
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
main_query += "ORDER BY LOWER(name) ASC"
cursor = self._conn.cursor()
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
rows = cursor.fetchall()
rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# First delete all existing default style presets
cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
# Next, parse and create the default style presets
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)

View File

@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
@@ -51,9 +51,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
try:
with self._db.transaction() as cursor:
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
@@ -64,18 +63,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_with_id.id, workflow_with_id.model_dump_json()),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE workflow_library
@@ -84,18 +78,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow.model_dump_json(), workflow.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from workflow_library
@@ -103,10 +92,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(
@@ -121,108 +106,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
with self._db.transaction() as cursor:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
# We will construct the query dynamically based on the query params
# We will construct the query dynamically based on the query params
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
conditions.append(category_condition)
params.extend(category_params)
conditions.append(category_condition)
params.extend(category_params)
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
conditions.append(tags_condition)
params.extend(tags_params)
conditions.append(tags_condition)
params.extend(tags_params)
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
# Put a ring on it
main_query += ";"
count_query += ";"
# Put a ring on it
main_query += ";"
count_query += ";"
cursor = self._conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
@@ -247,46 +232,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
return result
@@ -296,52 +281,51 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
return result
def update_opened_at(self, workflow_id: str) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -350,10 +334,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -368,8 +348,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
@@ -449,8 +428,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(w.model_dump_json(), w.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise

View File

@@ -123,7 +123,11 @@ def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
if total_steps == 0:
return 0.0
if order == 2:
return floor(step / 2) / floor(total_steps / 2)
# Prevent division by zero when total_steps is 1 or 2
denominator = floor(total_steps / 2)
if denominator == 0:
return 0.0
return floor(step / 2) / denominator
# order == 1
return step / total_steps
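A standalone sketch of the guarded second-order branch, with illustrative values:

    from math import floor

    def second_order_fraction(step: int, total_steps: int) -> float:
        # Mirrors the guard above: floor(total_steps / 2) can be 0 when total_steps is 1.
        denominator = floor(total_steps / 2)
        if denominator == 0:
            return 0.0
        return floor(step / 2) / denominator

    assert second_order_fraction(step=0, total_steps=1) == 0.0   # previously a ZeroDivisionError
    assert second_order_fraction(step=2, total_steps=4) == 0.5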

View File

@@ -30,8 +30,11 @@ def denoise(
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
# extra img tokens
# extra img tokens (channel-wise)
img_cond: torch.Tensor | None,
# extra img tokens (sequence-wise) - for Kontext conditioning
img_cond_seq: torch.Tensor | None = None,
img_cond_seq_ids: torch.Tensor | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -46,6 +49,10 @@ def denoise(
)
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# Store original sequence length for slicing predictions
original_seq_len = img.shape[1]
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -71,10 +78,26 @@ def denoise(
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
# Prepare input for model - concatenate fresh each step
img_input = img
img_input_ids = img_ids
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
if img_cond is not None:
img_input = torch.cat((img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (for Kontext)
if img_cond_seq is not None:
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred = model(
img=pred_img,
img_ids=img_ids,
img=img_input,
img_ids=img_input_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
@@ -88,6 +111,10 @@ def denoise(
regional_prompting_extension=pos_regional_prompting_extension,
)
# Slice prediction to only include the main image tokens
if img_input_ids is not None:
pred = pred[:, :original_seq_len]
step_cfg_scale = cfg_scale[step_index]
# If step_cfg_scale is 1.0, then we don't need to run the negative prediction.
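The pattern in this hunk: channel-wise conditioning (FLUX Fill, ControlNet) is concatenated along the feature dimension, while Kontext reference tokens are appended along the sequence dimension and then sliced back off the prediction so only the main image tokens are stepped. A toy illustration with random tensors (shapes are arbitrary and the transformer call is stubbed out):

import torch

batch, main_tokens, ref_tokens, channels = 1, 4096, 1024, 64

img = torch.randn(batch, main_tokens, channels)          # packed latents being denoised
img_cond_seq = torch.randn(batch, ref_tokens, channels)  # packed Kontext reference latents
original_seq_len = img.shape[1]

# Sequence-wise conditioning: append the reference tokens along dim=1.
img_input = torch.cat((img, img_cond_seq), dim=1)  # shape [1, 5120, 64]

# Stand-in for the transformer: one prediction per input token.
pred = img_input * 0.5

# Keep only the main-image tokens; the reference tokens conditioned the
# prediction but are not themselves denoised.
pred = pred[:, :original_seq_len]
assert pred.shape == (batch, main_tokens, channels)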

View File

@@ -0,0 +1,149 @@
import einops
import numpy as np
import torch
from einops import repeat
from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
def generate_img_ids_with_offset(
latent_height: int,
latent_width: int,
batch_size: int,
device: torch.device,
dtype: torch.dtype,
idx_offset: int = 0,
) -> torch.Tensor:
"""Generate tensor of image position ids with an optional offset.
Args:
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
latent_width (int): Width of image in latent space (after packing, this becomes w//2).
batch_size (int): Number of images in the batch.
device (torch.device): Device to create tensors on.
dtype (torch.dtype): Data type for the tensors.
idx_offset (int): Offset to add to the first dimension of the image ids.
Returns:
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
"""
if device.type == "mps":
orig_dtype = dtype
dtype = torch.float16
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
packed_height = latent_height // 2
packed_width = latent_width // 2
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
# The 3 channels represent: [batch_offset, y_position, x_position]
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
# Set the batch offset for all positions
img_ids[..., 0] = idx_offset
# Create y-coordinate indices (vertical positions)
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
img_ids[..., 1] = y_indices[:, None]
# Create x-coordinate indices (horizontal positions)
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
img_ids[..., 2] = x_indices[None, :]
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
if device.type == "mps":
img_ids = img_ids.to(orig_dtype)
return img_ids
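For a concrete sense of the output, here is what the function above returns for a tiny latent, assuming the definition as written: each row is [idx_offset, y, x], so passing idx_offset=1 keeps the reference image's position ids disjoint from the main image's ids (which use offset 0).

import torch

ids = generate_img_ids_with_offset(
    latent_height=4,
    latent_width=4,
    batch_size=1,
    device=torch.device("cpu"),
    dtype=torch.float32,
    idx_offset=1,
)
print(ids.shape)   # torch.Size([1, 4, 3]) -> (4 // 2) * (4 // 2) packed positions
print(ids[0, :3])  # tensor([[1., 0., 0.], [1., 0., 1.], [1., 1., 0.]])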
class KontextExtension:
"""Applies FLUX Kontext (reference image) conditioning."""
def __init__(
self,
kontext_conditioning: FluxKontextConditioningField,
context: InvocationContext,
vae_field: VAEField,
device: torch.device,
dtype: torch.dtype,
):
"""
Initializes the KontextExtension, pre-processing the reference image
into latents and positional IDs.
"""
self._context = context
self._device = device
self._dtype = dtype
self._vae_field = vae_field
self.kontext_conditioning = kontext_conditioning
# Pre-process and cache the kontext latents and ids upon initialization.
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encodes the reference image and prepares its latents and IDs."""
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
# Calculate aspect ratio of input image
width, height = image.size
aspect_ratio = width / height
# Find the closest preferred resolution by aspect ratio
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
# This ensures compatibility with the model's training
scaled_width = 2 * int(target_width / 16)
scaled_height = 2 * int(target_height / 16)
# Resize to the exact resolution used during training
image = image.convert("RGB")
final_width = 8 * scaled_width
final_height = 8 * scaled_height
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
# Convert to tensor with same normalization as BFL
image_np = np.array(image)
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
vae_info = self._context.models.load(self._vae_field.vae)
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pack the latents and generate IDs
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1,
)
return kontext_latents_packed, kontext_ids
def ensure_batch_size(self, target_batch_size: int) -> None:
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""
if self.kontext_latents.shape[0] != target_batch_size:
self.kontext_latents = self.kontext_latents.repeat(target_batch_size, 1, 1)
self.kontext_ids = self.kontext_ids.repeat(target_batch_size, 1, 1)
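A quick worked example of the resolution logic in _prepare_kontext, assuming PREFERED_KONTEXT_RESOLUTIONS from the util module shown further below: a 1920x1080 input (aspect ratio ~1.78) is matched to the (1392, 752) training resolution (~1.85), and the `2 * int(x / 16)` followed by `* 8` step simply snaps the result to a multiple of 16.

width, height = 1920, 1080
aspect_ratio = width / height  # ~1.778

# Pick the training resolution whose aspect ratio is closest to the input's.
_, target_width, target_height = min(
    ((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS),
    key=lambda x: x[0],
)  # -> (1392, 752), ratio ~1.851

# Snap to a multiple of 16 (a no-op here, since the preferred resolutions already are).
final_width = 8 * (2 * int(target_width / 16))    # 1392
final_height = 8 * (2 * int(target_height / 16))  # 752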

View File

@@ -174,11 +174,13 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
dtype = torch.float16
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
# Set batch offset to 0 for main image tokens
img_ids[..., 0] = 0
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
if device.type == "mps":
img_ids.to(orig_dtype)
img_ids = img_ids.to(orig_dtype)
return img_ids
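The one-line fix in this hunk addresses a silent no-op: Tensor.to() returns a converted copy rather than converting in place, so on MPS the ids stayed in float16 unless the result was reassigned. A two-line check:

import torch

t = torch.zeros(2, dtype=torch.float16)
t.to(torch.float32)      # returns a new tensor; t is still float16
t = t.to(torch.float32)  # reassign for the conversion to take effect
assert t.dtype == torch.float32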

View File

@@ -18,6 +18,29 @@ class ModelSpec:
repo_ae: str | None
# Preferred resolutions for Kontext models to avoid tiling artifacts
# These are the specific resolutions the model was trained on
PREFERED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-dev-fill": 512,

View File

@@ -187,7 +187,7 @@ class ModelConfigBase(ABC, BaseModel):
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("No valid config found")
raise InvalidModelConfigException("Unable to determine model type")
@classmethod
def get_tag(cls) -> Tag:

View File

@@ -7,7 +7,14 @@ from typing import Optional
import accelerate
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from transformers import (
AutoConfig,
AutoModelForTextEncoding,
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -139,7 +146,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
)
match submodel_type:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
@@ -183,7 +190,7 @@ class T5EncoderCheckpointModel(ModelLoader):
match submodel_type:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(
Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True

View File

@@ -23,7 +23,7 @@ class StarterModel(StarterModelWithoutDependencies):
dependencies: Optional[list[StarterModelWithoutDependencies]] = None
class StarterModelBundles(BaseModel):
class StarterModelBundle(BaseModel):
name: str
models: list[StarterModel]
@@ -109,7 +109,7 @@ flux_vae = StarterModel(
# region: Main
flux_schnell_quantized = StarterModel(
name="FLUX Schnell (Quantized)",
name="FLUX.1 schnell (quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
@@ -117,7 +117,7 @@ flux_schnell_quantized = StarterModel(
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_dev_quantized = StarterModel(
name="FLUX Dev (Quantized)",
name="FLUX.1 dev (quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
@@ -125,7 +125,7 @@ flux_dev_quantized = StarterModel(
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_schnell = StarterModel(
name="FLUX Schnell",
name="FLUX.1 schnell",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
@@ -133,13 +133,29 @@ flux_schnell = StarterModel(
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_dev = StarterModel(
name="FLUX Dev",
name="FLUX.1 dev",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext = StarterModel(
name="FLUX.1 Kontext dev",
base=BaseModelType.Flux,
source="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/resolve/main/flux1-kontext-dev.safetensors",
description="FLUX.1 Kontext dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext_quantized = StarterModel(
name="FLUX.1 Kontext dev (Quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
@@ -656,6 +672,7 @@ flux_fill = StarterModel(
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
flux_kontext_quantized,
flux_schnell_quantized,
flux_dev_quantized,
flux_schnell,
@@ -776,12 +793,13 @@ flux_bundle: list[StarterModel] = [
flux_depth_control_lora,
flux_redux,
flux_fill,
flux_kontext_quantized,
]
STARTER_BUNDLES: dict[str, list[StarterModel]] = {
BaseModelType.StableDiffusion1: sd1_bundle,
BaseModelType.StableDiffusionXL: sdxl_bundle,
BaseModelType.Flux: flux_bundle,
STARTER_BUNDLES: dict[str, StarterModelBundle] = {
BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle),
BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", models=sdxl_bundle),
BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle),
}
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"

View File

@@ -17,6 +17,15 @@ module.exports = {
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
// Restrict setActiveTab calls to only use-navigation-api.tsx
'no-restricted-syntax': [
'error',
{
selector: 'CallExpression[callee.name="setActiveTab"]',
message:
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
},
],
// TODO: ENABLE THIS RULE BEFORE v6.0.0
'react/display-name': 'off',
'no-restricted-properties': [
@@ -56,6 +65,15 @@ module.exports = {
],
},
overrides: [
/**
* Allow setActiveTab calls only in use-navigation-api.tsx
*/
{
files: ['**/use-navigation-api.tsx'],
rules: {
'no-restricted-syntax': 'off',
},
},
/**
* Overrides for stories
*/

View File

@@ -3,8 +3,6 @@ import type { KnipConfig } from 'knip';
const config: KnipConfig = {
project: ['src/**/*.{ts,tsx}!'],
ignore: [
// TODO(psyche): temporarily ignored all files for test build purposes
'src/**',
// This file is only used during debugging
'src/app/store/middleware/debugLoggerMiddleware.ts',
// Autogenerated types - shouldn't ever touch these
@@ -14,10 +12,8 @@ const config: KnipConfig = {
'src/features/parameters/types/parameterSchemas.ts',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
// TODO(psyche): restore HRF functionality?
'src/features/hrf/**',
// This feature is (temporarily?) disabled
'src/features/controlLayers/components/InpaintMask/InpaintMaskAddButtons.tsx',
// Will be using this
'src/common/hooks/useAsyncState.ts',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@@ -38,19 +38,6 @@
"test:ui": "vitest --coverage --ui",
"test:no-watch": "vitest --no-watch"
},
"madge": {
"excludeRegExp": [
"^index.ts$"
],
"detectiveOptions": {
"ts": {
"skipTypeImports": true
},
"tsx": {
"skipTypeImports": true
}
}
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.7.4",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.1",
@@ -64,6 +51,7 @@
"@reduxjs/toolkit": "2.8.2",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.7.1",
"ag-psd": "^28.2.1",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.1.1",
@@ -75,7 +63,7 @@
"framer-motion": "^11.10.0",
"i18next": "^25.2.1",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "^6.2.2",
"idb-keyval": "6.2.1",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.20",
"linkify-react": "^4.3.1",
@@ -146,7 +134,7 @@
"eslint": "^8.57.1",
"eslint-plugin-i18next": "^6.1.1",
"eslint-plugin-path": "^1.3.0",
"knip": "^5.50.5",
"knip": "^5.61.3",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",
@@ -155,7 +143,7 @@
"tsafe": "^1.8.5",
"type-fest": "^4.40.0",
"typescript": "^5.8.3",
"vite": "^6.3.3",
"vite": "^7.0.2",
"vite-plugin-css-injected-by-js": "^3.5.2",
"vite-plugin-dts": "^4.5.3",
"vite-plugin-eslint": "^1.8.1",
@@ -163,7 +151,7 @@
"vitest": "^3.1.2"
},
"engines": {
"pnpm": "8"
"pnpm": "10"
},
"packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
"packageManager": "pnpm@10.12.4"
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,3 @@
onlyBuiltDependencies:
- '@swc/core'
- esbuild

View File

@@ -225,7 +225,16 @@
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noMatchingTriggers": "No matching triggers"
"noMatchingTriggers": "No matching triggers",
"generateFromImage": "Generate prompt from image",
"expandCurrentPrompt": "Expand Current Prompt",
"uploadImageForPromptGeneration": "Upload Image for Prompt Generation",
"expandingPrompt": "Expanding prompt...",
"resultTitle": "Prompt Expansion Complete",
"resultSubtitle": "Choose how to handle the expanded prompt:",
"replace": "Replace",
"insert": "Insert",
"discard": "Discard"
},
"queue": {
"queue": "Queue",
@@ -335,14 +344,14 @@
"images": "Images",
"assets": "Assets",
"alwaysShowImageSizeBadge": "Always Show Image Size Badge",
"assetsTab": "Files youve uploaded for use in your projects.",
"assetsTab": "Files you've uploaded for use in your projects.",
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
"autoSwitchNewImages": "Auto-Switch to New Images",
"boardsSettings": "Boards Settings",
"copy": "Copy",
"currentlyInUse": "This image is currently in use in the following features:",
"drop": "Drop",
"dropOrUpload": "$t(gallery.drop) or Upload",
"dropOrUpload": "Drop or Upload",
"dropToUpload": "$t(gallery.drop) to Upload",
"deleteImage_one": "Delete Image",
"deleteImage_other": "Delete {{count}} Images",
@@ -357,7 +366,7 @@
"gallerySettings": "Gallery Settings",
"go": "Go",
"image": "image",
"imagesTab": "Images youve created and saved within Invoke.",
"imagesTab": "Images you've created and saved within Invoke.",
"imagesSettings": "Gallery Images Settings",
"jump": "Jump",
"loading": "Loading",
@@ -396,7 +405,8 @@
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit.",
"openViewer": "Open Viewer",
"closeViewer": "Close Viewer",
"move": "Move"
"move": "Move",
"useForPromptGeneration": "Use for Prompt Generation"
},
"hotkeys": {
"hotkeys": "Hotkeys",
@@ -579,6 +589,16 @@
"cancelTransform": {
"title": "Cancel Transform",
"desc": "Cancel the pending transform."
},
"settings": {
"behavior": "Behavior",
"display": "Display",
"grid": "Grid",
"debug": "Debug"
},
"toggleNonRasterLayers": {
"title": "Toggle Non-Raster Layers",
"desc": "Show or hide all non-raster layer categories (Control Layers, Inpaint Masks, Regional Guidance)."
}
},
"workflows": {
@@ -742,7 +762,7 @@
"vae": "VAE",
"width": "Width",
"workflow": "Workflow",
"canvasV2Metadata": "Canvas"
"canvasV2Metadata": "Canvas Layers"
},
"modelManager": {
"active": "active",
@@ -763,7 +783,7 @@
"convertToDiffusers": "Convert To Diffusers",
"convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in the InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
@@ -806,7 +826,11 @@
"urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.",
"urlUnauthorizedErrorMessage2": "Learn how here.",
"imageEncoderModelId": "Image Encoder Model ID",
"includesNModels": "Includes {{n}} models and their dependencies",
"installedModelsCount": "{{installed}} of {{total}} models installed.",
"includesNModels": "Includes {{n}} models and their dependencies.",
"allNModelsInstalled": "All {{count}} models installed",
"nToInstall": "{{count}} to install",
"nAlreadyInstalled": "{{count}} already installed",
"installQueue": "Install Queue",
"inplaceInstall": "In-place install",
"inplaceInstallDesc": "Install models without copying the files. When using the model, it will be loaded from its this location. If disabled, the model file(s) will be copied into the Invoke-managed models directory during installation.",
@@ -869,6 +893,25 @@
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"bundleAlreadyInstalled": "Bundle already installed",
"bundleAlreadyInstalledDesc": "All models in the {{bundleName}} bundle are already installed.",
"launchpadTab": "Launchpad",
"launchpad": {
"welcome": "Welcome to Model Management",
"description": "Invoke requires models to be installed to utilize most features of the platform. Choose from manual installation options or explore curated starter models.",
"manualInstall": "Manual Installation",
"urlDescription": "Install models from a URL or local file path. Perfect for specific models you want to add.",
"huggingFaceDescription": "Browse and install models directly from HuggingFace repositories.",
"scanFolderDescription": "Scan a local folder to automatically detect and install models.",
"recommendedModels": "Recommended Models",
"exploreStarter": "Or browse all available starter models",
"quickStart": "Quick Start Bundles",
"bundleDescription": "Each bundle includes essential models for each model family and curated base models to get started.",
"browseAll": "Or browse all available models:",
"stableDiffusion15": "Stable Diffusion 1.5",
"sdxl": "SDXL",
"fluxDev": "FLUX.1 dev"
},
"controlLora": "Control LoRA",
"llavaOnevision": "LLaVA OneVision",
"syncModels": "Sync Models",
@@ -905,7 +948,8 @@
"selectModel": "Select a Model",
"noLoRAsInstalled": "No LoRAs installed",
"noRefinerModelsInstalled": "No SDXL Refiner models installed",
"defaultVAE": "Default VAE"
"defaultVAE": "Default VAE",
"noCompatibleLoRAs": "No Compatible LoRAs"
},
"nodes": {
"arithmeticSequence": "Arithmetic Sequence",
@@ -1155,7 +1199,9 @@
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected"
"systemDisconnected": "System disconnected",
"promptExpansionPending": "Prompt expansion in progress",
"promptExpansionResultPending": "Please accept or discard your prompt expansion result"
},
"maskBlur": "Mask Blur",
"negativePromptPlaceholder": "Negative Prompt",
@@ -1313,6 +1359,21 @@
"problemCopyingLayer": "Unable to Copy Layer",
"problemSavingLayer": "Unable to Save Layer",
"problemDownloadingImage": "Unable to Download Image",
"noRasterLayers": "No Raster Layers Found",
"noRasterLayersDesc": "Create at least one raster layer to export to PSD",
"noActiveRasterLayers": "No Active Raster Layers",
"noActiveRasterLayersDesc": "Enable at least one raster layer to export to PSD",
"noVisibleRasterLayers": "No Visible Raster Layers",
"noVisibleRasterLayersDesc": "Enable at least one raster layer to export to PSD",
"invalidCanvasDimensions": "Invalid Canvas Dimensions",
"canvasTooLarge": "Canvas Too Large",
"canvasTooLargeDesc": "Canvas dimensions exceed the maximum allowed size for PSD export. Reduce the total width and height of the canvas of the canvas and try again.",
"failedToProcessLayers": "Failed to Process Layers",
"psdExportSuccess": "PSD Export Complete",
"psdExportSuccessDesc": "Successfully exported {{count}} layers to PSD file",
"problemExportingPSD": "Problem Exporting PSD",
"canvasManagerNotAvailable": "Canvas Manager Not Available",
"noValidLayerAdapters": "No Valid Layer Adapters Found",
"pasteSuccess": "Pasted to {{destination}}",
"pasteFailed": "Paste Failed",
"prunedQueue": "Pruned Queue",
@@ -1338,10 +1399,15 @@
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"fluxKontextIncompatibleGenerationMode": "Flux Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext does not support generation from images placed on the canvas. Re-try using the Reference Image section and disable any Raster Layers.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished"
"workflowUnpublished": "Workflow Unpublished",
"sentToCanvas": "Sent to Canvas",
"sentToUpscale": "Sent to Upscale",
"promptGenerationStarted": "Prompt generation started",
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again."
},
"popovers": {
"clipSkip": {
@@ -1864,6 +1930,7 @@
"saveCanvasToGallery": "Save Canvas to Gallery",
"saveBboxToGallery": "Save Bbox to Gallery",
"saveLayerToAssets": "Save Layer to Assets",
"exportCanvasToPSD": "Export Canvas to PSD",
"cropLayerToBbox": "Crop Layer to Bbox",
"savedToGalleryOk": "Saved to Gallery",
"savedToGalleryError": "Error saving to gallery",
@@ -1889,11 +1956,13 @@
"mergingLayers": "Merging layers",
"clearHistory": "Clear History",
"bboxOverlay": "Show Bbox Overlay",
"ruleOfThirds": "Show Rule of Thirds",
"newSession": "New Session",
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"saveAllImagesToGallery": "Save All Images to Gallery",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -1994,6 +2063,8 @@
"disableTransparencyEffect": "Disable Transparency Effect",
"hidingType": "Hiding {{type}}",
"showingType": "Showing {{type}}",
"showNonRasterLayers": "Show Non-Raster Layers (Shift+H)",
"hideNonRasterLayers": "Hide Non-Raster Layers (Shift+H)",
"dynamicGrid": "Dynamic Grid",
"logDebugInfo": "Log Debug Info",
"locked": "Locked",
@@ -2260,6 +2331,9 @@
"label": "Preserve Masked Region",
"alert": "Preserving Masked Region"
},
"saveAllImagesToGallery": {
"alert": "Saving All Images to Gallery"
},
"isolatedStagingPreview": "Isolated Staging Preview",
"isolatedPreview": "Isolated Preview",
"isolatedLayerPreview": "Isolated Layer Preview",
@@ -2288,6 +2362,7 @@
"newGlobalReferenceImage": "New Global Reference Image",
"newRegionalReferenceImage": "New Regional Reference Image",
"newControlLayer": "New Control Layer",
"newResizedControlLayer": "New Resized Control Layer",
"newRasterLayer": "New Raster Layer",
"newInpaintMask": "New Inpaint Mask",
"newRegionalGuidance": "New Regional Guidance",
@@ -2305,6 +2380,11 @@
"saveToGallery": "Save To Gallery",
"showResultsOn": "Showing Results",
"showResultsOff": "Hiding Results"
},
"autoSwitch": {
"off": "Off",
"switchOnStart": "On Start",
"switchOnFinish": "On Finish"
}
},
"upscaling": {
@@ -2371,7 +2451,8 @@
"uploadImage": "Upload Image",
"useForTemplate": "Use For Prompt Template",
"viewList": "View Template List",
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box."
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box.",
"togglePromptPreviews": "Toggle Prompt Previews"
},
"upsell": {
"inviteTeammates": "Invite Teammates",
@@ -2391,6 +2472,55 @@
"upscaling": "Upscaling",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
"gallery": "Gallery"
},
"launchpad": {
"workflowsTitle": "Go deep with Workflows.",
"upscalingTitle": "Upscale and add detail.",
"canvasTitle": "Edit and refine on Canvas.",
"generateTitle": "Generate images from text prompts.",
"modelGuideText": "Want to learn what prompts work best for each model?",
"modelGuideLink": "Check out our Model Guide.",
"workflows": {
"description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.",
"learnMoreLink": "Learn more about creating workflows",
"browseTemplates": {
"title": "Browse Workflow Templates",
"description": "Choose from pre-built workflows for common tasks"
},
"createNew": {
"title": "Create a new Workflow",
"description": "Start a new workflow from scratch"
},
"loadFromFile": {
"title": "Load workflow from file",
"description": "Upload a workflow to start with an existing setup"
}
},
"upscaling": {
"uploadImage": {
"title": "Upload Image to Upscale",
"description": "Click or drag an image to upscale (JPG, PNG, WebP up to 100MB)"
},
"replaceImage": {
"title": "Replace Current Image",
"description": "Click or drag a new image to replace the current one"
},
"imageReady": {
"title": "Image Ready",
"description": "Press Invoke to begin upscaling"
},
"readyToUpscale": {
"title": "Ready to upscale!",
"description": "Configure your settings below, then click the Invoke button to begin upscaling your image."
},
"upscaleModel": "Upscale Model",
"model": "Model",
"scale": "Scale",
"helpText": {
"promptAdvice": "When upscaling, use a prompt that describes the medium and style. Avoid describing specific content details in the image.",
"styleAdvice": "Upscaling works best with the general style of your image."
}
}
}
},
"system": {
@@ -2430,8 +2560,9 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Inpainting: Per-mask noise levels and denoise limits.",
"Canvas: Smarter aspect ratios for SDXL and improved scroll-to-zoom."
"Generate images faster with new Launchpads and a simplified Generate tab.",
"Edit with prompts using Flux Kontext Dev.",
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
@@ -2440,62 +2571,16 @@
"supportVideos": {
"supportVideos": "Support Videos",
"gettingStarted": "Getting Started",
"controlCanvas": "Control Canvas",
"watch": "Watch",
"studioSessionsDesc1": "Check out the <StudioSessionsPlaylistLink /> for Invoke deep dives.",
"studioSessionsDesc2": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"studioSessionsDesc": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"videos": {
"creatingYourFirstImage": {
"title": "Creating Your First Image",
"description": "Introduction to creating an image from scratch using Invoke's tools."
"gettingStarted": {
"title": "Getting Started with Invoke",
"description": "Complete video series covering everything you need to know to get started with Invoke, from creating your first image to advanced techniques."
},
"usingControlLayersAndReferenceGuides": {
"title": "Using Control Layers and Reference Guides",
"description": "Learn how to guide your image creation with control layers and reference images."
},
"understandingImageToImageAndDenoising": {
"title": "Understanding Image-to-Image and Denoising",
"description": "Overview of image-to-image transformations and denoising in Invoke."
},
"exploringAIModelsAndConceptAdapters": {
"title": "Exploring AI Models and Concept Adapters",
"description": "Dive into AI models and how to use concept adapters for creative control."
},
"creatingAndComposingOnInvokesControlCanvas": {
"title": "Creating and Composing on Invoke's Control Canvas",
"description": "Learn to compose images using Invoke's control canvas."
},
"upscaling": {
"title": "Upscaling",
"description": "How to upscale images with Invoke's tools to enhance resolution."
},
"howDoIGenerateAndSaveToTheGallery": {
"title": "How Do I Generate and Save to the Gallery?",
"description": "Steps to generate and save images to the gallery."
},
"howDoIEditOnTheCanvas": {
"title": "How Do I Edit on the Canvas?",
"description": "Guide to editing images directly on the canvas."
},
"howDoIDoImageToImageTransformation": {
"title": "How Do I Do Image-to-Image Transformation?",
"description": "Tutorial on performing image-to-image transformations in Invoke."
},
"howDoIUseControlNetsAndControlLayers": {
"title": "How Do I Use Control Nets and Control Layers?",
"description": "Learn to apply control layers and controlnets to your images."
},
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
"title": "How Do I Use Global IP Adapters and Reference Images?",
"description": "Introduction to adding reference images and global IP adapters."
},
"howDoIUseInpaintMasks": {
"title": "How Do I Use Inpaint Masks?",
"description": "How to apply inpaint masks for image correction and variation."
},
"howDoIOutpaint": {
"title": "How Do I Outpaint?",
"description": "Guide to outpainting beyond the original image borders."
"studioSessions": {
"title": "Studio Sessions",
"description": "Deep dive sessions exploring advanced Invoke features, creative workflows, and community discussions."
}
}
}

View File

@@ -2,8 +2,7 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $globalIsLoading } from 'app/store/nanostores/globalIsLoading';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
@@ -12,6 +11,7 @@ import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import ThemeLocaleProvider from './ThemeLocaleProvider';
const DEFAULT_CONFIG = {};
interface Props {
@@ -20,7 +20,7 @@ interface Props {
}
const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
const globalIsLoading = useStore($globalIsLoading);
const didStudioInit = useStore($didStudioInit);
const clearStorage = useClearStorage();
const handleReset = useCallback(() => {
@@ -31,12 +31,14 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{globalIsLoading && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
<ThemeLocaleProvider>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ThemeLocaleProvider>
</ErrorBoundary>
);
};

View File

@@ -1,4 +1,5 @@
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
import { setupListeners } from '@reduxjs/toolkit/query';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
@@ -10,13 +11,14 @@ import type { PartialAppConfig } from 'app/types/invokeai';
import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useCloseChakraTooltipsOnDragFix } from 'common/hooks/useCloseChakraTooltipsOnDragFix';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { size } from 'es-toolkit/compat';
import { useDndMonitor } from 'features/dnd/useDndMonitor';
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
import { useReadinessWatcher } from 'features/queue/store/readiness';
import { configChanged } from 'features/system/store/configSlice';
import { selectLanguage } from 'features/system/store/systemSelectors';
import { useNavigationApi } from 'features/ui/layouts/use-navigation-api';
import i18n from 'i18n';
import { memo, useEffect } from 'react';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
@@ -43,6 +45,8 @@ export const GlobalHookIsolator = memo(
useGetOpenAPISchemaQuery();
useSyncLoggingConfig();
useCloseChakraTooltipsOnDragFix();
useNavigationApi();
useDndMonitor();
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
// and/or in progress canvas sessions.
@@ -53,16 +57,18 @@ export const GlobalHookIsolator = memo(
}, [language]);
useEffect(() => {
if (size(config)) {
logger.info({ config }, 'Received config');
dispatch(configChanged(config));
}
logger.info({ config }, 'Received config');
dispatch(configChanged(config));
}, [dispatch, config, logger]);
useEffect(() => {
dispatch(appStarted());
}, [dispatch]);
useEffect(() => {
return setupListeners(dispatch);
}, [dispatch]);
useStudioInitAction(studioInitAction);
useStarterModelsToast();
useSyncQueueStatus();

View File

@@ -1,11 +1,14 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
@@ -27,59 +30,64 @@ GlobalImageHotkeys.displayName = 'GlobalImageHotkeys';
const GlobalImageHotkeysInternal = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const isGalleryFocused = useIsRegionFocused('gallery');
const isViewerFocused = useIsRegionFocused('viewer');
const imageActions = useImageActions(imageDTO);
const isStaging = useAppSelector(selectIsStaging);
const isUpscalingEnabled = useFeatureStatus('upscaling');
const isFocusOK = isGalleryFocused || isViewerFocused;
const recallAll = useRecallAll(imageDTO);
const recallRemix = useRecallRemix(imageDTO);
const recallPrompts = useRecallPrompts(imageDTO);
const recallSeed = useRecallSeed(imageDTO);
const recallDimensions = useRecallDimensions(imageDTO);
const loadWorkflow = useLoadWorkflow(imageDTO);
useRegisteredHotkeys({
id: 'loadWorkflow',
category: 'viewer',
callback: imageActions.loadWorkflow,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.loadWorkflow, isGalleryFocused, isViewerFocused],
callback: loadWorkflow.load,
options: { enabled: loadWorkflow.isEnabled && isFocusOK },
dependencies: [loadWorkflow, isFocusOK],
});
useRegisteredHotkeys({
id: 'recallAll',
category: 'viewer',
callback: imageActions.recallAll,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallAll, isStaging, isGalleryFocused, isViewerFocused],
callback: recallAll.recall,
options: { enabled: recallAll.isEnabled && isFocusOK },
dependencies: [recallAll, isFocusOK],
});
useRegisteredHotkeys({
id: 'recallSeed',
category: 'viewer',
callback: imageActions.recallSeed,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallSeed, isGalleryFocused, isViewerFocused],
callback: recallSeed.recall,
options: { enabled: recallSeed.isEnabled && isFocusOK },
dependencies: [recallSeed, isFocusOK],
});
useRegisteredHotkeys({
id: 'recallPrompts',
category: 'viewer',
callback: imageActions.recallPrompts,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallPrompts, isGalleryFocused, isViewerFocused],
callback: recallPrompts.recall,
options: { enabled: recallPrompts.isEnabled && isFocusOK },
dependencies: [recallPrompts, isFocusOK],
});
useRegisteredHotkeys({
id: 'remix',
category: 'viewer',
callback: imageActions.remix,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.remix, isGalleryFocused, isViewerFocused],
callback: recallRemix.recall,
options: { enabled: recallRemix.isEnabled && isFocusOK },
dependencies: [recallRemix, isFocusOK],
});
useRegisteredHotkeys({
id: 'useSize',
category: 'viewer',
callback: imageActions.recallSize,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallSize, isStaging, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'runPostprocessing',
category: 'viewer',
callback: imageActions.upscale,
options: { enabled: isUpscalingEnabled && isViewerFocused },
dependencies: [isUpscalingEnabled, imageDTO, isViewerFocused],
callback: recallDimensions.recall,
options: { enabled: recallDimensions.isEnabled && isFocusOK },
dependencies: [recallDimensions, isFocusOK],
});
return null;
});

View File

@@ -42,7 +42,6 @@ import { $socketOptions } from 'services/events/stores';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
interface Props extends PropsWithChildren {
apiUrl?: string;
@@ -330,9 +329,7 @@ const InvokeAIUI = ({
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
</ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</React.StrictMode>

View File

@@ -8,7 +8,7 @@ import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { sentImageToCanvas } from 'features/gallery/store/actions';
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
import { MetadataUtils } from 'features/metadata/parsing';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { $isWorkflowLibraryModalOpen } from 'features/nodes/store/workflowLibraryModal';
import {
@@ -19,7 +19,9 @@ import {
} from 'features/nodes/store/workflowLibrarySlice';
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { activeTabCanvasRightPanelChanged, setActiveTab } from 'features/ui/store/uiSlice';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
@@ -90,6 +92,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
};
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
store.dispatch(canvasReset());
store.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
store.dispatch(sentImageToCanvas());
@@ -116,23 +119,23 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const metadata = getImageMetadataResult.value;
store.dispatch(canvasReset());
// This shows a toast
await parseAndRecallAllMetadata(metadata, true);
await MetadataUtils.recallAll(metadata, store);
},
[store, t]
);
const handleLoadWorkflow = useCallback(
async (workflowId: string) => {
(workflowId: string) => {
// This shows a toast
await loadWorkflowWithDialog({
loadWorkflowWithDialog({
type: 'library',
data: workflowId,
onSuccess: () => {
store.dispatch(setActiveTab('workflows'));
navigationApi.switchToTab('workflows');
},
});
},
[loadWorkflowWithDialog, store]
[loadWorkflowWithDialog]
);
const handleSelectStylePreset = useCallback(
@@ -146,7 +149,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
return;
}
store.dispatch(activeStylePresetIdChanged(stylePresetId));
store.dispatch(setActiveTab('canvas'));
navigationApi.switchToTab('canvas');
toast({
title: t('toast.stylePresetLoaded'),
status: 'info',
@@ -156,33 +159,34 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
);
const handleGoToDestination = useCallback(
(destination: StudioDestinationAction['data']['destination']) => {
async (destination: StudioDestinationAction['data']['destination']) => {
switch (destination) {
case 'generation':
// Go to the canvas tab, open the image viewer, and enable send-to-gallery mode
// Go to the generate tab, open the launchpad
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
store.dispatch(paramsReset());
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
break;
case 'canvas':
// Go to the canvas tab, close the image viewer, and disable send-to-gallery mode
store.dispatch(canvasReset());
// Go to the canvas tab, open the launchpad
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
break;
case 'workflows':
// Go to the workflows tab
store.dispatch(setActiveTab('workflows'));
navigationApi.switchToTab('workflows');
break;
case 'upscaling':
// Go to the upscaling tab
store.dispatch(setActiveTab('upscaling'));
navigationApi.switchToTab('upscaling');
break;
case 'viewAllWorkflows':
// Go to the workflows tab and open the workflow library modal
store.dispatch(setActiveTab('workflows'));
navigationApi.switchToTab('workflows');
$isWorkflowLibraryModalOpen.set(true);
break;
case 'viewAllWorkflowsRecommended':
// Go to the workflows tab and open the workflow library modal with the recommended workflows view
store.dispatch(setActiveTab('workflows'));
navigationApi.switchToTab('workflows');
$isWorkflowLibraryModalOpen.set(true);
store.dispatch(workflowLibraryViewChanged('defaults'));
store.dispatch(workflowLibraryTagsReset());
@@ -194,7 +198,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
break;
case 'viewAllStylePresets':
// Go to the canvas tab and open the style presets menu
store.dispatch(setActiveTab('canvas'));
navigationApi.switchToTab('canvas');
$isStylePresetsMenuOpen.set(true);
break;
}

View File

@@ -8,7 +8,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
@@ -20,7 +19,6 @@ import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMi
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';
export const listenerMiddleware = createListenerMiddleware();
@@ -43,8 +41,6 @@ addImageUploadedFulfilledListener(startAppListening);
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);

View File

@@ -1,6 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -20,7 +20,7 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const board_id = selectSelectedBoardId(state);
const queryArgs = { ...selectListImageNamesQueryArgs(state), board_id };
const queryArgs = { ...selectGetImageNamesQueryArgs(state), board_id };
// wait until the board has some images - maybe it already has some from a previous fetch
// must use getState() to ensure we do not have stale state

View File

@@ -1,154 +0,0 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import {
canvasSessionIdCreated,
generateSessionIdCreated,
selectCanvasSessionId,
selectGenerateSessionId,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: enqueueRequestedCanvas,
effect: async (action, { getState, dispatch }) => {
log.debug('Enqueue requested');
const tab = selectActiveTab(getState());
let sessionId = null;
if (tab === 'generate') {
sessionId = selectGenerateSessionId(getState());
if (!sessionId) {
dispatch(generateSessionIdCreated());
sessionId = selectGenerateSessionId(getState());
}
} else if (tab === 'canvas') {
sessionId = selectCanvasSessionId(getState());
if (!sessionId) {
dispatch(canvasSessionIdCreated());
sessionId = selectCanvasSessionId(getState());
}
} else {
log.warn(`Enqueue requested in unsupported tab ${tab}`);
return;
}
const state = getState();
const destination = sessionId;
assert(destination !== null);
const { prepend } = action.payload;
const manager = $canvasManager.get();
// assert(manager, 'No canvas manager');
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, manager);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, manager);
case `sd-3`:
return await buildSD3Graph(state, manager);
case `flux`:
return await buildFLUXGraph(state, manager);
case 'cogview4':
return await buildCogView4Graph(state, manager);
case 'imagen3':
return await buildImagen3Graph(state, manager);
case 'imagen4':
return await buildImagen4Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
case 'flux-kontext':
return await buildFluxKontextGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: tab,
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, enqueueMutationFixedCacheKeyOptions)
);
try {
await req.unwrap();
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}
},
});
};

View File

@@ -1,44 +0,0 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
const log = logger('generation');
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: enqueueRequestedUpscaling,
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { prepend } = action.payload;
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
try {
await req.unwrap();
log.debug(parseify({ batchConfig }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}
},
});
};

View File

@@ -1,14 +1,28 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canvasSlice';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectBboxModelBase } from 'features/controlLayers/store/selectors';
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
import {
selectAllEntitiesOfType,
selectBboxModelBase,
selectCanvasSlice,
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { modelSelected } from 'features/parameters/store/actions';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { selectGlobalRefImageModels, selectRegionalRefImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig } from 'services/api/types';
import {
isChatGPT4oModelConfig,
isFluxKontextApiModelConfig,
isFluxKontextModelConfig,
isFluxReduxModelConfig,
} from 'services/api/types';
const log = logger('models');
@@ -25,9 +39,8 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
const newModel = result.data;
const newBaseModel = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBaseModel;
const newBase = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBase;
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
@@ -35,7 +48,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
// handle incompatible loras
state.loras.loras.forEach((lora) => {
if (lora.model.base !== newBaseModel) {
if (lora.model.base !== newBase) {
dispatch(loraDeleted({ id: lora.id }));
modelsCleared += 1;
}
@@ -43,20 +56,82 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
// handle incompatible vae
const { vae } = state.params;
if (vae && vae.base !== newBaseModel) {
if (vae && vae.base !== newBase) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// handle incompatible controlnets
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
// if (ca.model?.base !== newBaseModel) {
// modelsCleared += 1;
// if (ca.isEnabled) {
// dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
// }
// }
// });
// Handle incompatible reference image models - switch to the first compatible model, with some smart logic
// to choose the best available model based on the new main model.
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
let newGlobalRefImageModel = null;
// Certain models require the ref image model to be the same as the main model - others just need a matching
// base. Helper to grab the first exact match or the first available model if no exact match is found.
const exactMatchOrFirst = <T extends AnyModelConfig>(candidates: T[]): T | null =>
candidates.find(({ key }) => key === newModel.key) ?? candidates[0] ?? null;
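// e.g. (illustrative, not part of the change): given candidates [{ key: 'a', ... }, { key: 'b', ... }]
// and newModel.key === 'b', exactMatchOrFirst returns the 'b' config; if newModel.key were 'c', it
// falls back to the first candidate, and with an empty list it returns null.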
// The only way we can differentiate between FLUX and FLUX Kontext is to check for "kontext" in the name
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
const fluxKontextDevModels = allRefImageModels.filter(isFluxKontextModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextDevModels);
} else if (newModel.base === 'chatgpt-4o') {
const chatGPT4oModels = allRefImageModels.filter(isChatGPT4oModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(chatGPT4oModels);
} else if (newModel.base === 'flux-kontext') {
const fluxKontextApiModels = allRefImageModels.filter(isFluxKontextApiModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextApiModels);
} else if (newModel.base === 'flux') {
const fluxReduxModels = allRefImageModels.filter(isFluxReduxModelConfig);
newGlobalRefImageModel = fluxReduxModels[0] ?? null;
} else {
newGlobalRefImageModel = allRefImageModels[0] ?? null;
}
// All ref image entities are updated to use the same new model
const refImageEntities = selectReferenceImageEntities(state);
for (const entity of refImageEntities) {
const shouldUpdateModel =
(entity.config.model && entity.config.model.base !== newBase) ||
(!entity.config.model && newGlobalRefImageModel);
if (shouldUpdateModel) {
dispatch(
refImageModelChanged({
id: entity.id,
modelConfig: newGlobalRefImageModel,
})
);
modelsCleared += 1;
}
}
// For regional guidance, there is no smart logic - we just pick the first available model.
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;
// All regional guidance entities are updated to use the same new model.
const canvasState = selectCanvasSlice(state);
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
for (const entity of canvasRegionalGuidanceEntities) {
for (const refImage of entity.referenceImages) {
// Only change the model if the current one is not compatible with the new base model.
const shouldUpdateModel =
(refImage.config.model && refImage.config.model.base !== newBase) ||
(!refImage.config.model && newRegionalRefImageModel);
if (shouldUpdateModel) {
dispatch(
rgRefImageModelChanged({
entityIdentifier: getEntityIdentifier(entity),
referenceImageId: refImage.id,
modelConfig: newRegionalRefImageModel,
})
);
modelsCleared += 1;
}
}
}
if (modelsCleared > 0) {
toast({
@@ -71,9 +146,16 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
const modelBase = selectBboxModelBase(state);
if (!selectIsStaging(state) && modelBase !== state.params.model?.base) {
dispatch(bboxSyncedToOptimalDimension());
if (modelBase !== state.params.model?.base) {
// Sync generate tab settings whenever the model base changes
dispatch(syncedToOptimalDimension());
if (!selectIsStaging(state)) {
// Canvas tab only syncs if not staging
dispatch(bboxSyncedToOptimalDimension());
}
}
},
});

View File

@@ -1,7 +1,9 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
heightChanged,
setCfgRescaleMultiplier,
setCfgScale,
setGuidance,
@@ -9,6 +11,7 @@ import {
setSteps,
vaePrecisionChanged,
vaeSelected,
widthChanged,
} from 'features/controlLayers/store/paramsSlice';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
@@ -23,6 +26,7 @@ import {
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { t } from 'i18next';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
@@ -86,10 +90,16 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
}
if (cfg_rescale_multiplier) {
if (!isNil(cfg_rescale_multiplier)) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
} else {
// Set this to 0 if it doesn't have a default. This value is
// easy to miss in the UI when users are resetting defaults,
// and leaving it non-zero could have detrimental effects.
dispatch(setCfgRescaleMultiplier(0));
}
if (steps) {
@@ -106,15 +116,24 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
const setSizeOptions = { updateAspectRatio: true, clamp: true };
const isStaging = selectIsStaging(getState());
if (!isStaging && width) {
const activeTab = selectActiveTab(getState());
if (activeTab === 'generate') {
if (isParameterWidth(width)) {
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
dispatch(widthChanged({ width, ...setSizeOptions }));
}
if (isParameterHeight(height)) {
dispatch(heightChanged({ height, ...setSizeOptions }));
}
}
if (!isStaging && height) {
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
if (activeTab === 'canvas') {
if (!isStaging) {
if (isParameterWidth(width)) {
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
}
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
}
}
}

View File

@@ -1,13 +0,0 @@
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import { atom, computed } from 'nanostores';
import { flushSync } from 'react-dom';
export const $isLayoutLoading = atom(false);
export const setIsLayoutLoading = (isLoading: boolean) => {
flushSync(() => {
$isLayoutLoading.set(isLoading);
});
};
export const $globalIsLoading = computed([$didStudioInit, $isLayoutLoading], (didStudioInit, isLayoutLoading) => {
return !didStudioInit || isLayoutLoading;
});

View File

@@ -1,4 +1,3 @@
import { useStore } from '@nanostores/react';
import type { AppStore } from 'app/store/store';
import { atom } from 'nanostores';
@@ -32,11 +31,3 @@ export const getStore = () => {
}
return store;
};
export const useAppStore = () => {
const store = useStore($store);
if (!store) {
throw new ReduxStoreNotInitialized();
}
return store;
};

View File

@@ -11,5 +11,7 @@ export const $false: ReadableAtom<boolean> = atom(false);
/**
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
* in a hook or component.
*
* @knipignore
*/
export const $true: ReadableAtom<boolean> = atom(true);

View File

@@ -17,7 +17,6 @@ import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/p
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice';
@@ -57,7 +56,6 @@ const allReducers = {
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[queueSlice.name]: queueSlice.reducer,
[hrfSlice.name]: hrfSlice.reducer,
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
@@ -103,7 +101,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[uiPersistConfig.name]: uiPersistConfig,
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
[canvasPersistConfig.name]: canvasPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,

View File

@@ -1,8 +1,8 @@
import type { AppThunkDispatch, RootState } from 'app/store/store';
import type { AppStore, AppThunkDispatch, RootState } from 'app/store/store';
import type { TypedUseSelectorHook } from 'react-redux';
import { useDispatch, useSelector, useStore } from 'react-redux';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
export const useAppStore = () => useStore<RootState>();
export const useAppStore = () => useStore.withTypes<AppStore>()();

View File

@@ -1,6 +1,6 @@
import type { Selector } from '@reduxjs/toolkit';
import { useAppStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { useEffect, useState } from 'react';
/**

View File

@@ -14,6 +14,7 @@ export type AppFeature =
| 'githubLink'
| 'discordLink'
| 'bugLink'
| 'aboutModal'
| 'localization'
| 'consoleLogging'
| 'dynamicPrompting'
@@ -29,7 +30,8 @@ export type AppFeature =
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll'
| 'chatGPT4oHigh';
| 'chatGPT4oHigh'
| 'modelRelationships';
/**
* A disable-able Stable Diffusion feature
*/
@@ -76,6 +78,7 @@ export type AppConfig = {
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
allowPublishWorkflows: boolean;
allowPromptExpansion: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];

View File

@@ -1,56 +0,0 @@
import { Box, type BoxProps, type SystemStyleObject } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { type FocusRegionName, useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { selectSystemShouldEnableHighlightFocusedRegions } from 'features/system/store/systemSlice';
import { memo, useMemo, useRef } from 'react';
interface FocusRegionWrapperProps extends BoxProps {
region: FocusRegionName;
focusOnMount?: boolean;
}
const FOCUS_REGION_STYLES: SystemStyleObject = {
position: 'relative',
'&[data-highlighted="true"]::after': {
borderColor: 'blue.700',
},
'&::after': {
content: '""',
position: 'absolute',
inset: 0,
zIndex: 1,
borderRadius: 'base',
border: '2px solid',
borderColor: 'transparent',
pointerEvents: 'none',
transition: 'border-color 0.1s ease-in-out',
},
};
export const FocusRegionWrapper = memo(
({ region, focusOnMount = false, sx, children, ...boxProps }: FocusRegionWrapperProps) => {
const shouldHighlightFocusedRegions = useAppSelector(selectSystemShouldEnableHighlightFocusedRegions);
const ref = useRef<HTMLDivElement>(null);
const options = useMemo(() => ({ focusOnMount }), [focusOnMount]);
useFocusRegion(region, ref, options);
const isFocused = useIsRegionFocused(region);
const isHighlighted = isFocused && shouldHighlightFocusedRegions;
return (
<Box
ref={ref}
tabIndex={-1}
sx={useMemo(() => ({ ...FOCUS_REGION_STYLES, ...sx }), [sx])}
data-highlighted={isHighlighted}
{...boxProps}
>
{children}
</Box>
);
}
);
FocusRegionWrapper.displayName = 'FocusRegionWrapper';

View File

@@ -8,21 +8,16 @@ const Loading = () => {
return (
<Flex
position="absolute"
width="100dvw"
height="100dvh"
alignItems="center"
justifyContent="center"
bg="#151519"
top={0}
right={0}
bottom={0}
left={0}
bg="hsl(220 12% 10% / 1)" // base.900
inset={0}
zIndex={99999}
>
<Image src={InvokeLogoWhite} w="8rem" h="8rem" />
<Spinner
label="Loading"
color="grey"
color="hsl(220 12% 68% / 1)" // base.300
position="absolute"
size="sm"
width="24px !important"

View File

@@ -87,7 +87,7 @@ export const buildGroup = <T extends object>(group: Omit<Group<T>, typeof unique
[uniqueGroupKey]: true,
});
const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
export const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
};
@@ -198,6 +198,10 @@ type PickerProps<T extends object> = {
* Whether the picker should be searchable. If true, renders a search input.
*/
searchable?: boolean;
/**
* Initial state for group toggles. If provided, groups will start with these states instead of all being disabled.
*/
initialGroupStates?: GroupStatusMap;
};
export type PickerContextState<T extends object> = {
@@ -310,9 +314,9 @@ const flattenOptions = <T extends object>(options: OptionOrGroup<T>[]): T[] => {
return flattened;
};
type GroupStatusMap = Record<string, boolean>;
export type GroupStatusMap = Record<string, boolean>;
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[], initialGroupStates?: GroupStatusMap) => {
const groupsWithOptions = useMemo(() => {
const ids: string[] = [];
for (const optionOrGroup of options) {
@@ -332,14 +336,16 @@ const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
const groupStatusMap = $groupStatusMap.get();
const newMap: GroupStatusMap = {};
for (const id of groupsWithOptions) {
if (newMap[id] === undefined) {
newMap[id] = false;
if (initialGroupStates && initialGroupStates[id] !== undefined) {
newMap[id] = initialGroupStates[id];
} else if (groupStatusMap[id] !== undefined) {
newMap[id] = groupStatusMap[id];
} else {
newMap[id] = false;
}
}
$groupStatusMap.set(newMap);
}, [groupsWithOptions, $groupStatusMap]);
}, [groupsWithOptions, $groupStatusMap, initialGroupStates]);
const toggleGroup = useCallback(
(idToToggle: string) => {
@@ -511,10 +517,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
OptionComponent = DefaultOptionComponent,
NextToSearchBar,
searchable,
initialGroupStates,
} = props;
const rootRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(optionsOrGroups);
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(
optionsOrGroups,
initialGroupStates
);
const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId));
const $compactView = useAtom(true);
const $optionsOrGroups = useAtom(optionsOrGroups);

View File

@@ -1,20 +1,15 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import {
useNewCanvasSession,
useNewGallerySession,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import { allEntitiesDeleted } from 'features/controlLayers/store/canvasSlice';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold, PiFilePlusBold } from 'react-icons/pi';
import { PiArrowsCounterClockwiseBold } from 'react-icons/pi';
export const SessionMenuItems = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { newGallerySessionWithDialog } = useNewGallerySession();
const { newCanvasSessionWithDialog } = useNewCanvasSession();
const resetCanvasLayers = useCallback(() => {
dispatch(allEntitiesDeleted());
}, [dispatch]);
@@ -23,12 +18,6 @@ export const SessionMenuItems = memo(() => {
}, [dispatch]);
return (
<>
<MenuItem icon={<PiFilePlusBold />} onClick={newGallerySessionWithDialog}>
{t('controlLayers.newGallerySession')}
</MenuItem>
<MenuItem icon={<PiFilePlusBold />} onClick={newCanvasSessionWithDialog}>
{t('controlLayers.newCanvasSession')}
</MenuItem>
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetCanvasLayers}>
{t('controlLayers.resetCanvasLayers')}
</MenuItem>

View File

@@ -27,26 +27,37 @@ import { objectKeys } from 'tsafe';
const log = logger('system');
const REGION_NAMES = [
'launchpad',
'viewer',
'gallery',
'boards',
'layers',
'canvas',
'workflows',
'progress',
'settings',
] as const;
/**
* The names of the focus regions.
*/
export type FocusRegionName = 'gallery' | 'layers' | 'canvas' | 'workflows' | 'viewer';
export type FocusRegionName = (typeof REGION_NAMES)[number];
/**
* A map of focus regions to the elements that are part of that region.
*/
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = {
gallery: new Set<HTMLElement>(),
layers: new Set<HTMLElement>(),
canvas: new Set<HTMLElement>(),
workflows: new Set<HTMLElement>(),
viewer: new Set<HTMLElement>(),
} as const;
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = REGION_NAMES.reduce(
(acc, region) => {
acc[region] = new Set<HTMLElement>();
return acc;
},
{} as Record<FocusRegionName, Set<HTMLElement>>
);
/**
* The currently-focused region or `null` if no region is focused.
*/
export const $focusedRegion = atom<FocusRegionName | null>(null);
const $focusedRegion = atom<FocusRegionName | null>(null);
/**
* A map of focus regions to atoms that indicate if that region is focused.
@@ -62,11 +73,13 @@ const FOCUS_REGIONS = objectKeys(REGION_TARGETS).reduce(
/**
* Sets the focused region, logging a trace level message.
*/
const setFocus = (region: FocusRegionName | null) => {
export const setFocusedRegion = (region: FocusRegionName | null) => {
$focusedRegion.set(region);
log.trace(`Focus changed: ${region}`);
};
export const getFocusedRegion = () => $focusedRegion.get();
type UseFocusRegionOptions = {
focusOnMount?: boolean;
};
@@ -99,14 +112,14 @@ export const useFocusRegion = (
REGION_TARGETS[region].add(element);
if (focusOnMount) {
setFocus(region);
setFocusedRegion(region);
}
return () => {
REGION_TARGETS[region].delete(element);
if (REGION_TARGETS[region].size === 0 && $focusedRegion.get() === region) {
setFocus(null);
setFocusedRegion(null);
}
};
}, [options, ref, region]);
@@ -163,7 +176,7 @@ const onFocus = (_: FocusEvent) => {
return;
}
setFocus(focusedRegion);
setFocusedRegion(focusedRegion);
};
/**

View File

@@ -0,0 +1,115 @@
import { useStore } from '@nanostores/react';
import { WrappedError } from 'common/util/result';
import type { Atom } from 'nanostores';
import { atom } from 'nanostores';
import { useCallback, useEffect, useMemo, useState } from 'react';
type SuccessState<T> = {
status: 'success';
value: T;
error: null;
};
type ErrorState = {
status: 'error';
value: null;
error: Error;
};
type PendingState = {
status: 'pending';
value: null;
error: null;
};
type IdleState = {
status: 'idle';
value: null;
error: null;
};
export type State<T> = IdleState | PendingState | SuccessState<T> | ErrorState;
type UseAsyncStateOptions = {
immediate?: boolean;
};
type UseAsyncReturn<T> = {
$state: Atom<State<T>>;
trigger: () => Promise<void>;
reset: () => void;
};
export const useAsyncState = <T>(execute: () => Promise<T>, options?: UseAsyncStateOptions): UseAsyncReturn<T> => {
const $state = useState(() =>
atom<State<T>>({
status: 'idle',
value: null,
error: null,
})
)[0];
const trigger = useCallback(async () => {
$state.set({
status: 'pending',
value: null,
error: null,
});
try {
const value = await execute();
$state.set({
status: 'success',
value,
error: null,
});
} catch (error) {
$state.set({
status: 'error',
value: null,
error: WrappedError.wrap(error),
});
}
}, [$state, execute]);
const reset = useCallback(() => {
$state.set({
status: 'idle',
value: null,
error: null,
});
}, [$state]);
useEffect(() => {
if (options?.immediate) {
trigger();
}
}, [options?.immediate, trigger]);
const api = useMemo(
() =>
({
$state,
trigger,
reset,
}) satisfies UseAsyncReturn<T>,
[$state, trigger, reset]
);
return api;
};
type UseAsyncReturnReactive<T> = {
state: State<T>;
trigger: () => Promise<void>;
reset: () => void;
};
export const useAsyncStateReactive = <T>(
execute: () => Promise<T>,
options?: UseAsyncStateOptions
): UseAsyncReturnReactive<T> => {
const { $state, trigger, reset } = useAsyncState(execute, options);
const state = useStore($state);
return { state, trigger, reset };
};
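A minimal usage sketch of the reactive variant, assuming it runs inside a component and that `fetchQueueStatus` is an illustrative async function rather than part of this change:

```ts
// Illustrative only - `fetchQueueStatus` is an assumed async function; this runs inside a React component.
const { state, trigger } = useAsyncStateReactive(fetchQueueStatus, { immediate: true });

// `state` is a discriminated union, so narrowing on `status` types `value` and `error` correctly.
if (state.status === 'error') {
  console.warn(state.error.message); // always an Error, because the hook wraps non-Error throws via WrappedError.wrap
}
```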

View File

@@ -73,7 +73,7 @@ export const useBoolean = (initialValue: boolean): UseBoolean => {
};
};
export type UseDisclosure = {
type UseDisclosure = {
isOpen: boolean;
open: () => void;
close: () => void;

View File

@@ -1,165 +0,0 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/**
* Adapted from https://github.com/chakra-ui/chakra-ui/blob/v2/packages/hooks/src/use-outside-click.ts
*
* The main change here is to support filtering of outside clicks via a `filter` function.
*
* This lets us work around issues with portals and components like popovers, which typically close on an outside click.
*
* For example, consider a popover that has a custom drop-down component inside it, which uses a portal to render
* the drop-down options. The original outside click handler would close the popover when clicking on the drop-down options,
* because the click is outside the popover - but we expect the popover to stay open in this case.
*
* A filter function like this can fix that:
*
* ```ts
* const filter = (el: HTMLElement) => el.className.includes('chakra-portal') || el.id.includes('react-select')
* ```
*
* This ignores clicks on react-select-based drop-downs and Chakra UI portals and is used as the default filter.
*/
import { useCallback, useEffect, useRef } from 'react';
type FilterFunction = (el: HTMLElement | SVGElement) => boolean;
export function useCallbackRef<T extends (...args: any[]) => any>(
callback: T | undefined,
deps: React.DependencyList = []
) {
const callbackRef = useRef(callback);
useEffect(() => {
callbackRef.current = callback;
});
// eslint-disable-next-line react-hooks/exhaustive-deps
return useCallback(((...args) => callbackRef.current?.(...args)) as T, deps);
}
export interface UseOutsideClickProps {
/**
* Whether the hook is enabled
*/
enabled?: boolean;
/**
* The reference to a DOM element.
*/
ref: React.RefObject<HTMLElement | null>;
/**
* Function invoked when a click is triggered outside the referenced element.
*/
handler?: (e: Event) => void;
/**
* A function that filters the elements that should be considered as outside clicks.
*
* If omitted, a default filter function that ignores clicks in Chakra UI portals and react-select components is used.
*/
filter?: FilterFunction;
}
export const DEFAULT_FILTER: FilterFunction = (el) => {
if (el instanceof SVGElement) {
// SVGElement's type appears to be incorrect. Its className is not a string, which causes `includes` to fail.
// Let's assume that SVG elements with a class name are not part of the portal and should not be filtered.
return false;
}
return el.className.includes('chakra-portal') || el.id.includes('react-select');
};
/**
* Example, used in components like Dialogs and Popovers, so they can close
* when a user clicks outside them.
*/
export function useFilterableOutsideClick(props: UseOutsideClickProps) {
const { ref, handler, enabled = true, filter = DEFAULT_FILTER } = props;
const savedHandler = useCallbackRef(handler);
const stateRef = useRef({
isPointerDown: false,
ignoreEmulatedMouseEvents: false,
});
const state = stateRef.current;
useEffect(() => {
if (!enabled) {
return;
}
const onPointerDown: any = (e: PointerEvent) => {
if (isValidEvent(e, ref, filter)) {
state.isPointerDown = true;
}
};
const onMouseUp: any = (event: MouseEvent) => {
if (state.ignoreEmulatedMouseEvents) {
state.ignoreEmulatedMouseEvents = false;
return;
}
if (state.isPointerDown && handler && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const onTouchEnd = (event: TouchEvent) => {
state.ignoreEmulatedMouseEvents = true;
if (handler && state.isPointerDown && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const doc = getOwnerDocument(ref.current);
doc.addEventListener('mousedown', onPointerDown, true);
doc.addEventListener('mouseup', onMouseUp, true);
doc.addEventListener('touchstart', onPointerDown, true);
doc.addEventListener('touchend', onTouchEnd, true);
return () => {
doc.removeEventListener('mousedown', onPointerDown, true);
doc.removeEventListener('mouseup', onMouseUp, true);
doc.removeEventListener('touchstart', onPointerDown, true);
doc.removeEventListener('touchend', onTouchEnd, true);
};
}, [handler, ref, savedHandler, state, enabled, filter]);
}
function isValidEvent(event: Event, ref: React.RefObject<HTMLElement | null>, filter?: FilterFunction): boolean {
const target = (event.composedPath?.()[0] ?? event.target) as HTMLElement;
if (target) {
const doc = getOwnerDocument(target);
if (!doc.contains(target)) {
return false;
}
}
if (ref.current?.contains(target)) {
return false;
}
// This is the main logic change from the original hook.
if (filter) {
// Check if the click is inside an element matching the filter.
// This is used for portal-awareness or other general exclusion cases.
let currentElement: HTMLElement | null = target;
// Traverse up the DOM tree from the target element.
while (currentElement && currentElement !== document.body) {
if (filter(currentElement)) {
return false;
}
currentElement = currentElement.parentElement;
}
}
// If the click is not inside the ref and not inside a portal, it's a valid outside click.
return true;
}
function getOwnerDocument(node?: Element | null): Document {
return node?.ownerDocument ?? document;
}

View File

@@ -1,13 +1,17 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppStore } from 'app/store/storeHooks';
import { useDeleteImageModalApi } from 'features/deleteImageModal/store/state';
import { selectSelection } from 'features/gallery/store/gallerySelectors';
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
import { useDeleteCurrentQueueItem } from 'features/queue/hooks/useDeleteCurrentQueueItem';
import { useInvoke } from 'features/queue/hooks/useInvoke';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { getFocusedRegion } from './focus';
export const useGlobalHotkeys = () => {
const dispatch = useAppDispatch();
const { dispatch, getState } = useAppStore();
const isModelManagerEnabled = useFeatureStatus('modelManager');
const queue = useInvoke();
@@ -65,7 +69,7 @@ export const useGlobalHotkeys = () => {
id: 'selectGenerateTab',
category: 'app',
callback: () => {
dispatch(setActiveTab('generate'));
navigationApi.switchToTab('generate');
},
dependencies: [dispatch],
});
@@ -74,7 +78,7 @@ export const useGlobalHotkeys = () => {
id: 'selectCanvasTab',
category: 'app',
callback: () => {
dispatch(setActiveTab('canvas'));
navigationApi.switchToTab('canvas');
},
dependencies: [dispatch],
});
@@ -83,7 +87,7 @@ export const useGlobalHotkeys = () => {
id: 'selectUpscalingTab',
category: 'app',
callback: () => {
dispatch(setActiveTab('upscaling'));
navigationApi.switchToTab('upscaling');
},
dependencies: [dispatch],
});
@@ -92,7 +96,7 @@ export const useGlobalHotkeys = () => {
id: 'selectWorkflowsTab',
category: 'app',
callback: () => {
dispatch(setActiveTab('workflows'));
navigationApi.switchToTab('workflows');
},
dependencies: [dispatch],
});
@@ -101,7 +105,7 @@ export const useGlobalHotkeys = () => {
id: 'selectModelsTab',
category: 'app',
callback: () => {
dispatch(setActiveTab('models'));
navigationApi.switchToTab('models');
},
options: {
enabled: isModelManagerEnabled,
@@ -113,24 +117,26 @@ export const useGlobalHotkeys = () => {
id: 'selectQueueTab',
category: 'app',
callback: () => {
dispatch(setActiveTab('queue'));
navigationApi.switchToTab('queue');
},
dependencies: [dispatch, isModelManagerEnabled],
});
// TODO: implement delete - needs to handle gallery focus, which has changed w/ dockview
// useRegisteredHotkeys({
// id: 'deleteSelection',
// category: 'gallery',
// callback: () => {
// if (!selection.length) {
// return;
// }
// deleteImageModal.delete(selection);
// },
// options: {
// enabled: (isGalleryFocused || isImageViewerFocused) && isDeleteEnabledByTab && !isWorkflowsFocused,
// },
// dependencies: [isWorkflowsFocused, isDeleteEnabledByTab, selection, isWorkflowsFocused],
// });
const deleteImageModalApi = useDeleteImageModalApi();
useRegisteredHotkeys({
id: 'deleteSelection',
category: 'gallery',
callback: () => {
const focusedRegion = getFocusedRegion();
if (focusedRegion !== 'gallery' && focusedRegion !== 'viewer') {
return;
}
const selection = selectSelection(getState());
if (!selection.length) {
return;
}
deleteImageModalApi.delete(selection);
},
dependencies: [getState, deleteImageModalApi],
});
};

View File

@@ -21,11 +21,15 @@ type UseImageUploadButtonArgs =
isDisabled?: boolean;
allowMultiple: false;
onUpload?: (imageDTO: ImageDTO) => void;
onUploadStarted?: (files: File) => void;
onError?: (error: unknown) => void;
}
| {
isDisabled?: boolean;
allowMultiple: true;
onUpload?: (imageDTOs: ImageDTO[]) => void;
onUploadStarted?: (files: File[]) => void;
onError?: (error: unknown) => void;
};
const log = logger('gallery');
@@ -49,7 +53,13 @@ const log = logger('gallery');
* <Button {...getUploadButtonProps()} /> // will open the file dialog on click
* <input {...getUploadInputProps()} /> // hidden, handles native upload functionality
*/
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
export const useImageUploadButton = ({
onUpload,
isDisabled,
allowMultiple,
onUploadStarted,
onError,
}: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const [uploadImage, request] = useUploadImageMutation();
@@ -71,6 +81,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
}
const file = files[0];
assert(file !== undefined); // should never happen
onUploadStarted?.(file);
const imageDTO = await uploadImage({
file,
image_category: 'user',
@@ -82,6 +93,8 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
onUpload(imageDTO);
}
} else {
onUploadStarted?.(files);
let imageDTOs: ImageDTO[] = [];
if (isClientSideUploadEnabled && files.length > 1) {
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
@@ -102,6 +115,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
}
}
} catch (error) {
onError?.(error);
toast({
id: 'UPLOAD_FAILED',
title: t('toast.imageUploadFailed'),
@@ -109,7 +123,17 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
});
}
},
[allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t]
[
allowMultiple,
onUploadStarted,
uploadImage,
autoAddBoardId,
onUpload,
isClientSideUploadEnabled,
clientSideUpload,
onError,
t,
]
);
const onDropRejected = useCallback(
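For context, a minimal sketch of how a consumer might use the new callbacks; the component state and setter below are assumptions, not part of this diff:

```ts
// Illustrative consumer inside a component - `setIsUploading` is assumed local state.
const uploadApi = useImageUploadButton({
  allowMultiple: false,
  onUploadStarted: () => setIsUploading(true),   // new: fires before the upload request is dispatched
  onUpload: (imageDTO) => setIsUploading(false), // existing: receives the created ImageDTO on success
  onError: () => setIsUploading(false),          // new: lets the caller react to upload failures
});
```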

View File

@@ -1,4 +1,4 @@
import { useAppStore } from 'app/store/nanostores/store';
import { useAppStore } from 'app/store/storeHooks';
import { debounce } from 'es-toolkit/compat';
import type { Dimensions } from 'features/controlLayers/store/types';
import { selectUiSlice, textAreaSizesStateChanged } from 'features/ui/store/uiSlice';

View File

@@ -1,132 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { uniq } from 'es-toolkit/compat';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
import type { AnyModelConfig } from 'services/api/types';
import { useGroupedModelCombobox } from './useGroupedModelCombobox';
type UseRelatedGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
isLoading?: boolean;
groupByType?: boolean;
};
// Custom hook to overlay the grouped model combobox with related models on top!
// Cleaner than hooking into useGroupedModelCombobox with a flag to enable/disable the related models
// Also allows for related models to be shown conditionally with some pretty simple logic if it ends up as a config flag.
type UseRelatedGroupedModelComboboxReturn = {
value: ComboboxOption | undefined | null;
options: GroupBase<ComboboxOption>[];
onChange: ComboboxOnChange;
placeholder: string;
noOptionsMessage: () => string;
};
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
const keys: string[] = [];
const main = params.model;
const vae = params.vae;
const refiner = params.refinerModel;
const controlnet = params.controlLora;
if (main) {
keys.push(main.key);
}
if (vae) {
keys.push(vae.key);
}
if (refiner) {
keys.push(refiner.key);
}
if (controlnet) {
keys.push(controlnet.key);
}
for (const { model } of loras.loras) {
keys.push(model.key);
}
return uniq(keys);
});
export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
modelConfigs,
selectedModel,
onChange,
isLoading = false,
getIsDisabled,
groupByType,
}: UseRelatedGroupedModelComboboxArg<T>): UseRelatedGroupedModelComboboxReturn {
const { t } = useTranslation();
const selectedKeys = useAppSelector(selectSelectedModelKeys);
const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, {
selectFromResult: ({ data }) => {
if (!data) {
return { relatedKeys: EMPTY_ARRAY };
}
return { relatedKeys: data };
},
});
// Base grouped options
const base = useGroupedModelCombobox({
modelConfigs,
selectedModel,
onChange,
getIsDisabled,
isLoading,
groupByType,
});
const options = useMemo(() => {
if (relatedKeys.length === 0) {
return base.options;
}
const relatedOptions: ComboboxOption[] = [];
const updatedGroups: GroupBase<ComboboxOption>[] = [];
for (const group of base.options) {
const remainingOptions: ComboboxOption[] = [];
for (const option of group.options) {
if (relatedKeys.includes(option.value)) {
relatedOptions.push({ ...option, label: `* ${option.label}` });
} else {
remainingOptions.push(option);
}
}
if (remainingOptions.length > 0) {
updatedGroups.push({
label: group.label,
options: remainingOptions,
});
}
}
if (relatedOptions.length > 0) {
return [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups];
} else {
return updatedGroups;
}
}, [base.options, relatedKeys, t]);
return {
...base,
options,
};
}

View File

@@ -1,28 +0,0 @@
import type { Selector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import type { Atom, WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import { useEffect, useState } from 'react';
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const useSelectorAsAtom = <T extends Selector<RootState, any>>(selector: T): Atom<ReturnType<T>> => {
const store = useAppStore();
const $atom = useState<WritableAtom<ReturnType<T>>>(() => atom<ReturnType<T>>(selector(store.getState())))[0];
useEffect(() => {
const unsubscribe = store.subscribe(() => {
const prev = $atom.get();
const next = selector(store.getState());
if (prev !== next) {
$atom.set(next);
}
});
return () => {
unsubscribe();
};
}, [$atom, selector, store]);
return $atom;
};

View File

@@ -0,0 +1,20 @@
export type Deferred<T> = {
promise: Promise<T>;
resolve: (value: T) => void;
reject: (error: Error) => void;
};
/**
* Create a promise and expose its resolve and reject callbacks.
*/
export const createDeferredPromise = <T>(): Deferred<T> => {
let resolve!: (value: T) => void;
let reject!: (error: Error) => void;
const promise = new Promise<T>((res, rej) => {
resolve = res;
reject = rej;
});
return { promise, resolve, reject };
};
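A short usage sketch, assuming some UI elsewhere eventually settles the promise; `openConfirmationDialog` is illustrative and not part of this change:

```ts
// Illustrative usage - `openConfirmationDialog` is an assumed helper, not part of this diff.
const deferred = createDeferredPromise<boolean>();

// Hand the callbacks to whatever will eventually settle the promise...
openConfirmationDialog({
  onConfirm: () => deferred.resolve(true),
  onCancel: () => deferred.reject(new Error('Cancelled')),
});

// ...and await the outcome where the flow started.
const confirmed = await deferred.promise;
```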

View File

@@ -1,6 +0,0 @@
/**
* Get the keys of an object. This is a wrapper around `Object.keys` that types the result as an array of the keys of the object.
* @param obj The object to get the keys of.
* @returns The keys of the object.
*/
export const objectKeys = <T extends Record<string, unknown>>(obj: T) => Object.keys(obj) as Array<keyof T>;

View File

@@ -89,7 +89,7 @@ export function withResult<T>(fn: () => T): Result<T> {
try {
return new Ok(fn());
} catch (error) {
return new Err(error instanceof Error ? error : new Error(String(error)));
return new Err(error instanceof Error ? error : new WrappedError(error));
}
}
@@ -104,6 +104,23 @@ export async function withResultAsync<T>(fn: () => Promise<T>): Promise<Result<T
const result = await fn();
return new Ok(result);
} catch (error) {
return new Err(error instanceof Error ? error : new Error(String(error)));
return new Err(error instanceof Error ? error : new WrappedError(error));
}
}
export class WrappedError extends Error {
error: unknown;
constructor(error: unknown) {
super('Wrapped Error');
this.name = this.constructor.name;
this.error = error;
}
static wrap(error: unknown): Error | WrappedError {
if (error instanceof Error) {
return error;
}
return new WrappedError(error);
}
}
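A brief sketch of how the wrapper changes behavior for non-Error throws (the thrown value below is illustrative):

```ts
// Illustrative only: a function that throws a plain string rather than an Error.
const result = withResult(() => {
  throw 'something went wrong';
});

if (result.isErr()) {
  // The original thrown value is preserved on the WrappedError rather than being stringified away.
  console.log(result.error instanceof WrappedError); // true
  console.log((result.error as WrappedError).error); // 'something went wrong'
}
```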

View File

@@ -1,182 +0,0 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
ContextMenu,
Divider,
Flex,
IconButton,
Menu,
MenuButton,
MenuList,
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
} from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
import { CanvasAlertsInvocationProgress } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsInvocationProgress';
import { CanvasAlertsPreserveMask } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsPreserveMask';
import { CanvasAlertsSelectedEntityStatus } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSelectedEntityStatus';
import { CanvasContextMenuGlobalMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems';
import { CanvasContextMenuSelectedEntityMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuSelectedEntityMenuItems';
import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea';
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
import { SelectObject } from 'features/controlLayers/components/SelectObject/SelectObject';
import { CanvasSessionContextProvider } from 'features/controlLayers/components/SimpleSession/context';
import { GenerateLaunchpadPanel } from 'features/controlLayers/components/SimpleSession/GenerateLaunchpadPanel';
import { StagingAreaItemsList } from 'features/controlLayers/components/SimpleSession/StagingAreaItemsList';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
import { Transform } from 'features/controlLayers/components/Transform/Transform';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar';
import { memo, useCallback } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
const FOCUS_REGION_STYLES: SystemStyleObject = {
width: 'full',
height: 'full',
};
const MenuContent = memo(() => {
return (
<CanvasManagerProviderGate>
<MenuList>
<CanvasContextMenuSelectedEntityMenuItems />
<CanvasContextMenuGlobalMenuItems />
</MenuList>
</CanvasManagerProviderGate>
);
});
MenuContent.displayName = 'MenuContent';
const canvasBgSx = {
position: 'relative',
w: 'full',
h: 'full',
borderRadius: 'base',
overflow: 'hidden',
bg: 'base.900',
'&[data-dynamic-grid="true"]': {
bg: 'base.850',
},
};
export const AdvancedSession = memo(({ id }: { id: string | null }) => {
const dynamicGrid = useAppSelector(selectDynamicGrid);
const showHUD = useAppSelector(selectShowHUD);
const renderMenu = useCallback(() => {
return <MenuContent />;
}, []);
return (
<Tabs w="full" h="full">
<TabList>
<Tab>Welcome</Tab>
<Tab>Workspace</Tab>
<Tab>Viewer</Tab>
</TabList>
<TabPanels w="full" h="full">
<TabPanel w="full" h="full" justifyContent="center">
<GenerateLaunchpadPanel />
</TabPanel>
<TabPanel w="full" h="full">
<FocusRegionWrapper region="canvas" sx={FOCUS_REGION_STYLES}>
<Flex
tabIndex={-1}
borderRadius="base"
position="relative"
flexDirection="column"
height="full"
width="full"
gap={2}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
<CanvasManagerProviderGate>
<CanvasToolbar />
</CanvasManagerProviderGate>
<Divider />
<ContextMenu<HTMLDivElement> renderMenu={renderMenu} withLongPress={false}>
{(ref) => (
<Flex ref={ref} sx={canvasBgSx} data-dynamic-grid={dynamicGrid}>
<InvokeCanvasComponent />
<CanvasManagerProviderGate>
<Flex
position="absolute"
flexDir="column"
top={1}
insetInlineStart={1}
pointerEvents="none"
gap={2}
alignItems="flex-start"
>
{showHUD && <CanvasHUD />}
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsInvocationProgress />
</Flex>
<Flex position="absolute" top={1} insetInlineEnd={1}>
<Menu>
<MenuButton as={IconButton} icon={<PiDotsThreeOutlineVerticalFill />} colorScheme="base" />
<MenuContent />
</Menu>
</Flex>
</CanvasManagerProviderGate>
</Flex>
)}
</ContextMenu>
{id !== null && (
<CanvasManagerProviderGate>
<CanvasSessionContextProvider type="advanced" id={id}>
<Flex
position="absolute"
flexDir="column"
bottom={4}
gap={2}
align="center"
justify="center"
left={4}
right={4}
>
<Flex position="relative" maxW="full" w="full" h={108}>
<StagingAreaItemsList />
</Flex>
<Flex gap={2}>
<StagingAreaToolbar />
</Flex>
</Flex>
</CanvasSessionContextProvider>
</CanvasManagerProviderGate>
)}
<Flex position="absolute" bottom={4}>
<CanvasManagerProviderGate>
<Filter />
<Transform />
<SelectObject />
</CanvasManagerProviderGate>
</Flex>
<CanvasManagerProviderGate>
<CanvasDropArea />
</CanvasManagerProviderGate>
</Flex>
</FocusRegionWrapper>
</TabPanel>
<TabPanel w="full" h="full">
<Flex flexDir="column" w="full" h="full">
<ViewerToolbar />
<ImageViewer />
</Flex>
</TabPanel>
</TabPanels>
</Tabs>
);
});
AdvancedSession.displayName = 'AdvancedSession';

View File

@@ -0,0 +1,23 @@
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasAlertsSaveAllImagesToGallery = memo(() => {
const { t } = useTranslation();
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
if (!saveAllImagesToGallery) {
return null;
}
return (
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('controlLayers.settings.saveAllImagesToGallery.alert')}</AlertTitle>
</Alert>
);
});
CanvasAlertsSaveAllImagesToGallery.displayName = 'CanvasAlertsSaveAllImagesToGallery';

View File

@@ -1,3 +1,4 @@
import type { SpinnerProps } from '@invoke-ai/ui-library';
import { Spinner } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
@@ -5,7 +6,7 @@ import { useAllEntityAdapters } from 'features/controlLayers/contexts/EntityAdap
import { computed } from 'nanostores';
import { memo, useMemo } from 'react';
export const CanvasBusySpinner = memo(() => {
export const CanvasBusySpinner = memo((props: SpinnerProps) => {
const canvasManager = useCanvasManager();
const allEntityAdapters = useAllEntityAdapters();
const $isPendingRectCalculation = useMemo(
@@ -21,7 +22,7 @@ export const CanvasBusySpinner = memo(() => {
const isCompositing = useStore(canvasManager.compositor.$isBusy);
if (isRasterizing || isCompositing || isPendingRectCalculation) {
return <Spinner opacity={0.3} />;
return <Spinner opacity={0.3} {...props} />;
}
return null;
});

View File

@@ -12,6 +12,10 @@ const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'regional_guidance_with_reference_image',
});
const addResizedControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'control_layer',
withResize: true,
});
export const CanvasDropArea = memo(() => {
const { t } = useTranslation();
@@ -45,7 +49,6 @@ export const CanvasDropArea = memo(() => {
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
@@ -54,6 +57,14 @@ export const CanvasDropArea = memo(() => {
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addResizedControlLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newResizedControlLayer')}
isDisabled={isBusy}
/>
</GridItem>
</Grid>
</>
);

View File

@@ -10,6 +10,7 @@ import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScro
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
import { RasterLayerExportPSDButton } from 'features/controlLayers/components/RasterLayer/RasterLayerExportPSDButton';
import { useEntityTypeInformationalPopover } from 'features/controlLayers/hooks/useEntityTypeInformationalPopover';
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
import { entitiesReordered } from 'features/controlLayers/store/canvasSlice';
@@ -166,6 +167,7 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children, entityI
</Flex>
<CanvasEntityMergeVisibleButton type={type} />
<CanvasEntityTypeIsHiddenToggle type={type} />
{type === 'raster_layer' && <RasterLayerExportPSDButton />}
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>

View File

@@ -6,6 +6,7 @@ import { EntityListSelectedEntityActionBarFilterButton } from 'features/controlL
import { EntityListSelectedEntityActionBarOpacity } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarOpacity';
import { EntityListSelectedEntityActionBarSelectObjectButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarSelectObjectButton';
import { EntityListSelectedEntityActionBarTransformButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarTransformButton';
import { EntityListNonRasterLayerToggle } from 'features/controlLayers/components/common/CanvasNonRasterLayersIsHiddenToggle';
import { memo } from 'react';
import { EntityListSelectedEntityActionBarSaveToAssetsButton } from './EntityListSelectedEntityActionBarSaveToAssetsButton';
@@ -22,6 +23,7 @@ export const EntityListSelectedEntityActionBar = memo(() => {
<EntityListSelectedEntityActionBarTransformButton />
<EntityListSelectedEntityActionBarSaveToAssetsButton />
<EntityListSelectedEntityActionBarDuplicateButton />
<EntityListNonRasterLayerToggle />
<EntityListGlobalActionBarAddLayerMenu />
</Flex>
</Flex>

View File

@@ -14,7 +14,7 @@ export const CanvasLayersPanel = memo(() => {
return (
<CanvasManagerProviderGate>
<Flex flexDir="column" gap={2} w="full" h="full" p={2}>
<Flex flexDir="column" gap={2} w="full" h="full">
<EntityListSelectedEntityActionBar />
<Divider py={0} />
<ParamDenoisingStrength />

View File

@@ -10,8 +10,7 @@ import {
ModalOverlay,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';

View File

@@ -1,7 +1,6 @@
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';

View File

@@ -1,5 +1,5 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppStore } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';

View File

@@ -1,27 +0,0 @@
// import { Button, Flex } from '@invoke-ai/ui-library';
// import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
// import { useAddInpaintMaskDenoiseLimit, useAddInpaintMaskNoise } from 'features/controlLayers/hooks/addLayerHooks';
// import { useTranslation } from 'react-i18next';
// import { PiPlusBold } from 'react-icons/pi';
// Removed buttons because the denoise limit is not helpful for many architectures.
// Users can access it via the right-click menu instead.
// If buttons for noise or new features are deemed important in the future, add them back here.
export const InpaintMaskAddButtons = () => {
// Buttons are temporarily hidden. To restore, uncomment the code below.
return null;
// const entityIdentifier = useEntityIdentifierContext('inpaint_mask');
// const { t } = useTranslation();
// const addInpaintMaskDenoiseLimit = useAddInpaintMaskDenoiseLimit(entityIdentifier);
// const addInpaintMaskNoise = useAddInpaintMaskNoise(entityIdentifier);
// return (
// <Flex w="full" p={2} justifyContent="center">
// <Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addInpaintMaskDenoiseLimit}>
// {t('controlLayers.denoiseLimit')}
// </Button>
// <Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addInpaintMaskNoise}>
// {t('controlLayers.imageNoise')}
// </Button>
// </Flex>
// );
};

View File

@@ -14,7 +14,7 @@ import { useTranslation } from 'react-i18next';
const [useNewGallerySessionDialog] = buildUseBoolean(false);
const [useNewCanvasSessionDialog] = buildUseBoolean(false);
export const useNewGallerySession = () => {
const useNewGallerySession = () => {
const dispatch = useAppDispatch();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewGallerySessionDialog();
@@ -35,7 +35,7 @@ export const useNewGallerySession = () => {
return { newGallerySessionImmediate, newGallerySessionWithDialog };
};
export const useNewCanvasSession = () => {
const useNewCanvasSession = () => {
const dispatch = useAppDispatch();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewCanvasSessionDialog();

View File

@@ -0,0 +1,31 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useExportCanvasToPSD } from 'features/controlLayers/hooks/useExportCanvasToPSD';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFileArrowDownBold } from 'react-icons/pi';
export const RasterLayerExportPSDButton = memo(() => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const { exportCanvasToPSD } = useExportCanvasToPSD();
const onClick = useCallback(() => {
exportCanvasToPSD();
}, [exportCanvasToPSD]);
return (
<IconButton
onClick={onClick}
isDisabled={isBusy}
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.exportCanvasToPSD')}
tooltip={t('controlLayers.exportCanvasToPSD')}
icon={<PiFileArrowDownBold />}
/>
);
});
RasterLayerExportPSDButton.displayName = 'RasterLayerExportPSDButton';

View File

@@ -4,9 +4,17 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { refImageDeleted, selectRefImageEntityIds } from 'features/controlLayers/store/refImagesSlice';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import {
refImageDeleted,
refImageIsEnabledToggled,
selectRefImageEntityIds,
} from 'features/controlLayers/store/refImagesSlice';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { memo, useCallback, useMemo } from 'react';
import { PiTrashBold } from 'react-icons/pi';
import { PiCircleBold, PiCircleFill, PiTrashBold, PiWarningBold } from 'react-icons/pi';
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
const textSx: SystemStyleObject = {
color: 'base.300',
@@ -24,25 +32,63 @@ export const RefImageHeader = memo(() => {
);
const refImageNumber = useAppSelector(selectRefImageNumber);
const entity = useRefImageEntity(id);
const mainModelConfig = useAppSelector(selectMainModelConfig);
const warnings = useMemo(() => {
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
}, [entity, mainModelConfig]);
const deleteRefImage = useCallback(() => {
dispatch(refImageDeleted({ id }));
}, [dispatch, id]);
const toggleIsEnabled = useCallback(() => {
dispatch(refImageIsEnabledToggled({ id }));
}, [dispatch, id]);
return (
<Flex justifyContent="space-between" alignItems="center" w="full" ps={2}>
<Text fontWeight="semibold" sx={textSx} data-is-error={!entity.config.image}>
Reference Image #{refImageNumber}
</Text>
<IconButton
tooltip="Delete Reference Image"
size="xs"
variant="link"
alignSelf="stretch"
aria-label="Delete ref image"
onClick={deleteRefImage}
icon={<PiTrashBold />}
colorScheme="error"
/>
<Flex alignItems="center" gap={1}>
{warnings.length > 0 && (
<IconButton
as="span"
size="sm"
variant="link"
alignSelf="stretch"
aria-label="warnings"
tooltip={<RefImageWarningTooltipContent warnings={warnings} />}
icon={<PiWarningBold />}
colorScheme="warning"
/>
)}
{!entity.isEnabled && (
<Text fontSize="xs" fontStyle="italic" color="base.400">
Disabled
</Text>
)}
<IconButton
tooltip={entity.isEnabled ? 'Disable Reference Image' : 'Enable Reference Image'}
size="xs"
variant="link"
alignSelf="stretch"
aria-label={entity.isEnabled ? 'Disable ref image' : 'Enable ref image'}
onClick={toggleIsEnabled}
icon={entity.isEnabled ? <PiCircleFill /> : <PiCircleBold />}
/>
<IconButton
tooltip="Delete Reference Image"
size="xs"
variant="link"
alignSelf="stretch"
aria-label="Delete ref image"
onClick={deleteRefImage}
icon={<PiTrashBold />}
colorScheme="error"
/>
</Flex>
</Flex>
);
});

View File

@@ -61,7 +61,7 @@ export const RefImageImage = memo(
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" w="full" />
<DndImage imageDTO={imageDTO} borderRadius="base" borderWidth={1} borderStyle="solid" w="full" />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}

View File

@@ -1,10 +1,12 @@
import { Button, Collapse, Divider, Flex } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { Button, Collapse, Divider, Flex, IconButton } from '@invoke-ai/ui-library';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { RefImagePreview } from 'features/controlLayers/components/RefImage/RefImagePreview';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { RefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { useNewGlobalReferenceImageFromBbox } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusySafe } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
refImageAdded,
selectIsRefImagePanelOpen,
@@ -14,8 +16,10 @@ import {
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { addGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useMemo } from 'react';
import { PiUploadBold } from 'react-icons/pi';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiUploadBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import { RefImageHeader } from './RefImageHeader';
@@ -79,6 +83,7 @@ MaxRefImages.displayName = 'MaxRefImages';
const AddRefImageDropTargetAndButton = memo(() => {
const { dispatch, getState } = useAppStore();
const tab = useAppSelector(selectActiveTab);
const uploadOptions = useMemo(
() =>
@@ -96,7 +101,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
const uploadApi = useImageUploadButton(uploadOptions);
return (
<>
<Flex gap={1} h="full" w="full">
<Button
position="relative"
size="sm"
@@ -113,7 +118,31 @@ const AddRefImageDropTargetAndButton = memo(() => {
<input {...uploadApi.getUploadInputProps()} />
<DndDropTarget label="Drop" dndTarget={addGlobalReferenceImageDndTarget} dndTargetData={dndTargetData} />
</Button>
</>
{tab === 'canvas' && (
<CanvasManagerProviderGate>
<BboxButton />
</CanvasManagerProviderGate>
)}
</Flex>
);
});
const BboxButton = memo(() => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusySafe();
const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox();
return (
<IconButton
size="lg"
variant="outline"
h="full"
icon={<PiBoundingBoxBold />}
onClick={newGlobalReferenceImageFromBbox}
isDisabled={isBusy}
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
/>
);
});
AddRefImageDropTargetAndButton.displayName = 'AddRefImageDropTargetAndButton';

View File

@@ -1,25 +1,35 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { areBasesCompatibleForRefImage } from 'features/controlLayers/store/validators';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ApiModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
IPAdapterModelConfig,
} from 'services/api/types';
type Props = {
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => void;
onChangeModel: (
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig
) => void;
};
export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const mainModelConfig = useAppSelector(selectMainModelConfig);
const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null) => {
(
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig | null
) => {
if (!modelConfig) {
return;
}
@@ -29,12 +39,10 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const hasMainModel = Boolean(currentBaseModel);
const hasSameBase = currentBaseModel === model.base;
return !hasMainModel || !hasSameBase;
(model: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig): boolean => {
return !areBasesCompatibleForRefImage(mainModelConfig, model);
},
[currentBaseModel]
[mainModelConfig]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
@@ -47,7 +55,11 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
return (
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full" minW={0}>
<FormControl
isInvalid={!value || !areBasesCompatibleForRefImage(mainModelConfig, selectedModel)}
w="full"
minW={0}
>
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}

View File

@@ -1,24 +1,41 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { round } from 'es-toolkit/compat';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import {
refImageSelected,
selectIsRefImagePanelOpen,
selectSelectedRefEntityId,
} from 'features/controlLayers/store/refImagesSlice';
import { isIPAdapterConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { PiExclamationMarkBold, PiImageBold } from 'react-icons/pi';
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
const baseSx: SystemStyleObject = {
'&[data-is-open="true"]': {
borderColor: 'invokeBlue.300',
},
'&[data-is-disabled="true"]': {
img: {
opacity: 0.4,
filter: 'grayscale(100%)',
},
},
'&[data-is-error="true"]': {
borderColor: 'error.500',
img: {
opacity: 0.4,
filter: 'grayscale(100%)',
},
},
};
const weightDisplaySx: SystemStyleObject = {
@@ -51,6 +68,7 @@ export const RefImagePreview = memo(() => {
const dispatch = useAppDispatch();
const id = useRefImageIdContext();
const entity = useRefImageEntity(id);
const mainModelConfig = useAppSelector(selectMainModelConfig);
const selectedEntityId = useAppSelector(selectSelectedRefEntityId);
const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen);
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
@@ -76,6 +94,10 @@ export const RefImagePreview = memo(() => {
};
}, [entity.config]);
const warnings = useMemo(() => {
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
}, [entity, mainModelConfig]);
const onClick = useCallback(() => {
dispatch(refImageSelected({ id }));
}, [dispatch, id]);
@@ -97,66 +119,82 @@ export const RefImagePreview = memo(() => {
flexShrink={0}
data-is-open={selectedEntityId === id && isPanelOpen}
data-is-error={true}
data-is-disabled={!entity.isEnabled}
sx={sx}
/>
);
}
return (
<Flex
position="relative"
borderWidth={1}
borderStyle="solid"
borderRadius="base"
aspectRatio="1/1"
maxW="full"
maxH="full"
flexShrink={0}
sx={sx}
data-is-open={selectedEntityId === id && isPanelOpen}
data-is-error={!entity.config.model}
role="button"
onClick={onClick}
cursor="pointer"
>
<Image
src={imageDTO?.thumbnail_url}
objectFit="contain"
<Tooltip label={warnings.length > 0 ? <RefImageWarningTooltipContent warnings={warnings} /> : undefined}>
<Flex
position="relative"
borderWidth={1}
borderStyle="solid"
borderRadius="base"
aspectRatio="1/1"
height={imageDTO?.height}
fallback={<Skeleton h="full" aspectRatio="1/1" />}
maxW="full"
maxH="full"
borderRadius="base"
/>
{isIPAdapterConfig(entity.config) && (
<Flex
position="absolute"
inset={0}
fontWeight="semibold"
alignItems="center"
justifyContent="center"
zIndex={1}
data-visible={showWeightDisplay}
sx={weightDisplaySx}
>
<Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
{`${round(entity.config.weight * 100, 2)}%`}
</Text>
</Flex>
)}
{!entity.config.model && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="error.500"
boxSize={16}
as={PiExclamationMarkBold}
flexShrink={0}
sx={sx}
data-is-open={selectedEntityId === id && isPanelOpen}
data-is-error={warnings.length > 0}
data-is-disabled={!entity.isEnabled}
role="button"
onClick={onClick}
cursor="pointer"
overflow="hidden"
>
<Image
src={imageDTO?.thumbnail_url}
objectFit="contain"
aspectRatio="1/1"
height={imageDTO?.height}
fallback={<Skeleton h="full" aspectRatio="1/1" />}
maxW="full"
maxH="full"
/>
)}
</Flex>
{isIPAdapterConfig(entity.config) && (
<Flex
position="absolute"
inset={0}
fontWeight="semibold"
alignItems="center"
justifyContent="center"
zIndex={1}
data-visible={showWeightDisplay}
sx={weightDisplaySx}
>
<Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
{`${round(entity.config.weight * 100, 2)}%`}
</Text>
</Flex>
)}
{!entity.isEnabled && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="base.300"
boxSize={8}
as={PiEyeSlashBold}
/>
)}
{entity.isEnabled && warnings.length > 0 && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="error.500"
boxSize={12}
as={PiExclamationMarkBold}
/>
)}
</Flex>
</Tooltip>
);
});
RefImagePreview.displayName = 'RefImagePreview';

View File

@@ -38,7 +38,13 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo } from 'react';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
ImageDTO,
IPAdapterModelConfig,
} from 'services/api/types';
import { RefImageImage } from './RefImageImage';
@@ -84,7 +90,7 @@ const RefImageSettingsContent = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig) => {
dispatch(refImageModelChanged({ id, modelConfig }));
},
[dispatch, id]

View File

@@ -0,0 +1,18 @@
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
import { upperFirst } from 'es-toolkit/compat';
import { useTranslation } from 'react-i18next';
export const RefImageWarningTooltipContent = ({ warnings }: { warnings: string[] }) => {
const { t } = useTranslation();
return (
<Flex flexDir="column">
<Text fontWeight="semibold">Invalid Reference Image:</Text>
<UnorderedList>
{warnings.map((tKey) => (
<ListItem key={tKey}>{upperFirst(t(tKey))}</ListItem>
))}
</UnorderedList>
</Flex>
);
};

View File

@@ -83,7 +83,7 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
</Flex>
<Flex alignItems="center" gap={2} p={4}>
<Text textAlign="center" color="base.300">
<Trans i18nKey="controlLayers.referenceImageEmptyStateWithCanvasTab" components={components} />
<Trans i18nKey="controlLayers.referenceImageEmptyStateWithCanvasOptions" components={components} />
</Text>
</Flex>
<input {...uploadApi.getUploadInputProps()} />

View File

@@ -1,12 +1,14 @@
import {
Divider,
Flex,
Icon,
IconButton,
Popover,
PopoverArrow,
PopoverBody,
PopoverContent,
PopoverTrigger,
Text,
useShiftModifier,
} from '@invoke-ai/ui-library';
import { CanvasSettingsBboxOverlaySwitch } from 'features/controlLayers/components/Settings/CanvasSettingsBboxOverlaySwitch';
@@ -23,11 +25,13 @@ import { CanvasSettingsOutputOnlyMaskedRegionsCheckbox } from 'features/controlL
import { CanvasSettingsPreserveMaskCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsPreserveMaskCheckbox';
import { CanvasSettingsPressureSensitivityCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsPressureSensitivity';
import { CanvasSettingsRecalculateRectsButton } from 'features/controlLayers/components/Settings/CanvasSettingsRecalculateRectsButton';
import { CanvasSettingsRuleOfThirdsSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsRuleOfThirdsGuideSwitch';
import { CanvasSettingsSaveAllImagesToGalleryCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsSaveAllImagesToGalleryCheckbox';
import { CanvasSettingsShowHUDSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsShowHUDSwitch';
import { CanvasSettingsShowProgressOnCanvas } from 'features/controlLayers/components/Settings/CanvasSettingsShowProgressOnCanvasSwitch';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixFill } from 'react-icons/pi';
import { PiCodeFill, PiEyeFill, PiGearSixFill, PiPencilFill, PiSquaresFourFill } from 'react-icons/pi';
export const CanvasSettingsPopover = memo(() => {
const { t } = useTranslation();
@@ -41,22 +45,58 @@ export const CanvasSettingsPopover = memo(() => {
alignSelf="stretch"
/>
</PopoverTrigger>
<PopoverContent>
<PopoverContent maxW="280px">
<PopoverArrow />
<PopoverBody>
<Flex direction="column" gap={2}>
<CanvasSettingsInvertScrollCheckbox />
<CanvasSettingsPreserveMaskCheckbox />
<CanvasSettingsClipToBboxCheckbox />
<CanvasSettingsOutputOnlyMaskedRegionsCheckbox />
<CanvasSettingsSnapToGridCheckbox />
<CanvasSettingsPressureSensitivityCheckbox />
<CanvasSettingsShowProgressOnCanvas />
<CanvasSettingsIsolatedStagingPreviewSwitch />
<CanvasSettingsIsolatedLayerPreviewSwitch />
<CanvasSettingsDynamicGridSwitch />
<CanvasSettingsBboxOverlaySwitch />
<CanvasSettingsShowHUDSwitch />
{/* Behavior Settings */}
<Flex direction="column" gap={1}>
<Flex align="center" gap={2}>
<Icon as={PiPencilFill} boxSize={4} />
<Text fontWeight="bold" fontSize="sm" color="base.100">
{t('hotkeys.canvas.settings.behavior')}
</Text>
</Flex>
<CanvasSettingsInvertScrollCheckbox />
<CanvasSettingsPressureSensitivityCheckbox />
<CanvasSettingsPreserveMaskCheckbox />
<CanvasSettingsClipToBboxCheckbox />
<CanvasSettingsOutputOnlyMaskedRegionsCheckbox />
<CanvasSettingsSaveAllImagesToGalleryCheckbox />
</Flex>
<Divider />
{/* Display Settings */}
<Flex direction="column" gap={1}>
<Flex align="center" gap={2} color="base.200">
<Icon as={PiEyeFill} boxSize={4} />
<Text fontWeight="bold" fontSize="sm">
{t('hotkeys.canvas.settings.display')}
</Text>
</Flex>
<CanvasSettingsShowProgressOnCanvas />
<CanvasSettingsIsolatedStagingPreviewSwitch />
<CanvasSettingsIsolatedLayerPreviewSwitch />
<CanvasSettingsBboxOverlaySwitch />
<CanvasSettingsShowHUDSwitch />
</Flex>
<Divider />
{/* Grid Settings */}
<Flex direction="column" gap={1}>
<Flex align="center" gap={2} color="base.200">
<Icon as={PiSquaresFourFill} boxSize={4} />
<Text fontWeight="bold" fontSize="sm">
{t('hotkeys.canvas.settings.grid')}
</Text>
</Flex>
<CanvasSettingsSnapToGridCheckbox />
<CanvasSettingsDynamicGridSwitch />
<CanvasSettingsRuleOfThirdsSwitch />
</Flex>
<DebugSettings />
</Flex>
</PopoverBody>
@@ -68,6 +108,7 @@ export const CanvasSettingsPopover = memo(() => {
CanvasSettingsPopover.displayName = 'CanvasSettingsPopover';
const DebugSettings = () => {
const { t } = useTranslation();
const shift = useShiftModifier();
if (!shift) {
@@ -77,10 +118,18 @@ const DebugSettings = () => {
return (
<>
<Divider />
<CanvasSettingsClearCachesButton />
<CanvasSettingsRecalculateRectsButton />
<CanvasSettingsLogDebugInfoButton />
<CanvasSettingsClearHistoryButton />
<Flex direction="column" gap={1}>
<Flex align="center" gap={2} color="base.200">
<Icon as={PiCodeFill} boxSize={4} />
<Text fontWeight="bold" fontSize="sm">
{t('hotkeys.canvas.settings.debug')}
</Text>
</Flex>
<CanvasSettingsClearCachesButton />
<CanvasSettingsRecalculateRectsButton />
<CanvasSettingsLogDebugInfoButton />
<CanvasSettingsClearHistoryButton />
</Flex>
</>
);
};

View File

@@ -0,0 +1,25 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectRuleOfThirds, settingsRuleOfThirdsToggled } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsRuleOfThirdsSwitch = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const ruleOfThirds = useAppSelector(selectRuleOfThirds);
const onChange = useCallback(() => {
dispatch(settingsRuleOfThirdsToggled());
}, [dispatch]);
return (
<FormControl>
<FormLabel m={0} flexGrow={1}>
{t('controlLayers.ruleOfThirds')}
</FormLabel>
<Switch size="sm" isChecked={ruleOfThirds} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsRuleOfThirdsSwitch.displayName = 'CanvasSettingsRuleOfThirdsSwitch';

View File

@@ -0,0 +1,25 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectSaveAllImagesToGallery,
settingsSaveAllImagesToGalleryToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsSaveAllImagesToGalleryCheckbox = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
const onChange = useCallback(() => {
dispatch(settingsSaveAllImagesToGalleryToggled());
}, [dispatch]);
return (
<FormControl w="full">
<FormLabel flexGrow={1}>{t('controlLayers.saveAllImagesToGallery')}</FormLabel>
<Checkbox isChecked={saveAllImagesToGallery} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsSaveAllImagesToGalleryCheckbox.displayName = 'CanvasSettingsSaveAllImagesToGalleryCheckbox';

View File

@@ -1,42 +1,46 @@
import { Button, Flex, Grid, Heading, Text } from '@invoke-ai/ui-library';
import { useAutoLayoutContext } from 'features/ui/layouts/auto-layout-context';
import { Button, Flex, Grid, Text } from '@invoke-ai/ui-library';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { InitialStateMainModelPicker } from './InitialStateMainModelPicker';
import { LaunchpadAddStyleReference } from './LaunchpadAddStyleReference';
import { LaunchpadContainer } from './LaunchpadContainer';
import { LaunchpadEditImageButton } from './LaunchpadEditImageButton';
import { LaunchpadGenerateFromTextButton } from './LaunchpadGenerateFromTextButton';
import { LaunchpadUseALayoutImageButton } from './LaunchpadUseALayoutImageButton';
export const CanvasLaunchpadPanel = memo(() => {
const ctx = useAutoLayoutContext();
const { t } = useTranslation();
const focusCanvas = useCallback(() => {
ctx.focusPanel(WORKSPACE_PANEL_ID);
}, [ctx]);
navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
}, []);
return (
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
<Heading mb={4}>Edit and refine on Canvas.</Heading>
<Flex flexDir="column" gap={8}>
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
Want to learn what prompts work best for each model?{' '}
<Button as="a" variant="link" href="#" size="sm">
Check our our Model Guide.
</Button>
</Text>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton extraAction={focusCanvas} />
<LaunchpadAddStyleReference extraAction={focusCanvas} />
<LaunchpadEditImageButton extraAction={focusCanvas} />
<LaunchpadUseALayoutImageButton extraAction={focusCanvas} />
<LaunchpadContainer heading={t('ui.launchpad.canvasTitle')}>
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
{t('ui.launchpad.modelGuideText')}{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
{t('ui.launchpad.modelGuideLink')}
</Button>
</Text>
</Flex>
</Flex>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton extraAction={focusCanvas} />
<LaunchpadAddStyleReference extraAction={focusCanvas} />
<LaunchpadEditImageButton extraAction={focusCanvas} />
<LaunchpadUseALayoutImageButton extraAction={focusCanvas} />
</LaunchpadContainer>
);
});
CanvasLaunchpadPanel.displayName = 'CanvasLaunchpadPanel';

View File

@@ -1,47 +1,48 @@
import { Alert, Button, Flex, Grid, Heading, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { Alert, Button, Flex, Grid, Text } from '@invoke-ai/ui-library';
import { InitialStateMainModelPicker } from 'features/controlLayers/components/SimpleSession/InitialStateMainModelPicker';
import { LaunchpadAddStyleReference } from 'features/controlLayers/components/SimpleSession/LaunchpadAddStyleReference';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { memo, useCallback } from 'react';
import { LaunchpadContainer } from './LaunchpadContainer';
import { LaunchpadGenerateFromTextButton } from './LaunchpadGenerateFromTextButton';
export const GenerateLaunchpadPanel = memo(() => {
const dispatch = useAppDispatch();
const newCanvasSession = useCallback(() => {
dispatch(setActiveTab('canvas'));
}, [dispatch]);
navigationApi.switchToTab('canvas');
}, []);
return (
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
<Heading mb={4}>Generate images from text prompts.</Heading>
<Flex flexDir="column" gap={8}>
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
Want to learn what prompts work best for each model?{' '}
<Button as="a" variant="link" href="#" size="sm">
Check our our Model Guide.
</Button>
</Text>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton />
<LaunchpadAddStyleReference />
<Alert status="info" borderRadius="base" flexDir="column" gap={2} overflow="unset">
<Text fontSize="md" fontWeight="semibold">
Looking to get more control, edit, and iterate on your images?
</Text>
<Button variant="link" onClick={newCanvasSession}>
Navigate to Canvas for more capabilities.
<LaunchpadContainer heading="Generate images from text prompts.">
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
Want to learn what prompts work best for each model?{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
Check out our Model Guide.
</Button>
</Alert>
</Text>
</Flex>
</Flex>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton />
<LaunchpadAddStyleReference />
<Alert status="info" borderRadius="base" flexDir="column" gap={2} overflow="unset">
<Text fontSize="md" fontWeight="semibold">
Looking to get more control, edit, and iterate on your images?
</Text>
<Button variant="link" onClick={newCanvasSession}>
Navigate to Canvas for more capabilities.
</Button>
</Alert>
</LaunchpadContainer>
);
});
GenerateLaunchpadPanel.displayName = 'GenerateLaunchpad';

View File

@@ -1,28 +0,0 @@
import type { ButtonGroupProps } from '@invoke-ai/ui-library';
import { Button, ButtonGroup } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { newCanvasFromImage } from 'features/imageActions/actions';
import { memo, useCallback } from 'react';
import type { ImageDTO } from 'services/api/types';
export const ImageActions = memo(({ imageDTO, ...rest }: { imageDTO: ImageDTO } & ButtonGroupProps) => {
const { getState, dispatch } = useAppStore();
const edit = useCallback(() => {
newCanvasFromImage({
imageDTO,
type: 'raster_layer',
withInpaintMask: true,
getState,
dispatch,
});
}, [dispatch, getState, imageDTO]);
return (
<ButtonGroup isAttached={false} size="sm" {...rest}>
<Button onClick={edit} tooltip="Edit parts of this image with Inpainting">
Edit
</Button>
</ButtonGroup>
);
});
ImageActions.displayName = 'ImageActions';

Some files were not shown because too many files have changed in this diff.