Compare commits

..

175 Commits

Author SHA1 Message Date
psychedelicious
e5776e048d chore: bump version to v6.0.0a7 2025-06-25 14:14:48 +10:00
psychedelicious
10a87717cb chore(ui): dpdm 2025-06-25 14:12:21 +10:00
psychedelicious
62b0f7166a chore: ruff 2025-06-25 14:06:00 +10:00
psychedelicious
cb48244b9f fix(ui): restore gallery selection count tag 2025-06-25 14:06:00 +10:00
psychedelicious
e0224d73eb fix(ui): debounce gallery min width value 2025-06-25 14:06:00 +10:00
psychedelicious
a0ff1f6697 chore(ui): disable debug logger 2025-06-25 14:06:00 +10:00
psychedelicious
76e767d50d fix(ui): issues with progress viewer 2025-06-25 14:06:00 +10:00
psychedelicious
6a1bba819a refactor: remove unused methods/routes, fix some gallery invalidation issues 2025-06-25 14:06:00 +10:00
psychedelicious
92cd59324b feat(ui): restore gallery hotkeys (except delete) 2025-06-25 14:06:00 +10:00
psychedelicious
5ff1a9e9b5 fix(ui): gallery updates on image completion 2025-06-25 14:05:59 +10:00
psychedelicious
a364c28061 fix(ui): remove context from DOM props 2025-06-25 14:05:59 +10:00
psychedelicious
2c231075a3 feat(ui): calculate gridTemplateColumns in selector 2025-06-25 14:05:59 +10:00
psychedelicious
8dd7b29966 chore: ruff 2025-06-25 14:05:59 +10:00
psychedelicious
af83c2b5d7 chore: bump version to v6.0.0a6 2025-06-25 14:05:59 +10:00
psychedelicious
4be31bd5a1 fix(ui): minor jank when siwtching images rapidly 2025-06-25 14:05:59 +10:00
psychedelicious
ef519d418c feat(ui): scrollbar styles 2025-06-25 14:05:59 +10:00
psychedelicious
af5ddd9a7b refactor: gallery scroll (improved impl) 2025-06-25 14:05:59 +10:00
psychedelicious
41bc2cfc2f refactor: gallery scroll (improved impl) 2025-06-25 14:05:59 +10:00
psychedelicious
a32d1147d2 refactor: gallery scroll (improved impl) 2025-06-25 14:05:59 +10:00
psychedelicious
b2c5b6f039 refactor: gallery scroll (improved impl) 2025-06-25 14:05:59 +10:00
psychedelicious
b2cdf14a32 refactor: gallery scroll (improved impl) 2025-06-25 14:05:59 +10:00
psychedelicious
8f736bff3f refactor: gallery scroll (improved impl) 2025-06-25 14:05:59 +10:00
psychedelicious
42e326304c refactor: gallery scroll (improved impl) 2025-06-25 14:05:58 +10:00
psychedelicious
3d1642ce52 refactor: gallery scroll (improved impl) 2025-06-25 14:05:58 +10:00
psychedelicious
a339cec36f refactor: gallery scroll (improved impl) 2025-06-25 14:05:58 +10:00
psychedelicious
054428730d refactor: gallery scroll (improved impl) 2025-06-25 14:05:58 +10:00
psychedelicious
b51e6c8783 refactor: gallery scroll (improved impl) 2025-06-25 14:05:58 +10:00
psychedelicious
cc255435c7 refactor: gallery scroll 2025-06-25 14:05:58 +10:00
psychedelicious
a5baa53f07 fix(ui): fix metadata toggle stuck disabled 2025-06-25 14:05:58 +10:00
psychedelicious
32704b41a1 chore: bump version to v6.0.0a5 2025-06-25 14:05:58 +10:00
psychedelicious
c43c9358b9 chore(ui): lint 2025-06-25 14:05:58 +10:00
psychedelicious
76284a22a0 refactor(ui): use image names for selection instead of dtos
Update the frontend to incorporate the previous changes to how image
selection and general image identification is handled in the frontend.
2025-06-25 14:05:58 +10:00
psychedelicious
c83bdc58c9 chore(ui): typegen 2025-06-25 14:05:58 +10:00
psychedelicious
a63482b582 feat(api): return more data when doing image/board mutations
When we delete images, boards, or do any other board mutation, we need
to invalidate numerous query caches and related internal frontend state.
This gets complicated very quickly.

We can drastically reduce the complexity by having the backend return
some more information when we make these mutations.

For example, when deleting a list of images by name, we can return a
list of deleted image name and affected boards. The frontend can use
this information to determine which queries to invalidate with far less
tedium.

This will also enable the more efficient storage of images (e.g. in the
gallery selection). Previously, we had to store the entire image DTO
object, else we wouldn't be able to figure out which queries to
invalidate. But now that the backend tells us exactly what images/boards
have changed, we can just store image names in frontend state. This
amounts to a substantial improvement in DX and reduction in frontend
complexity.
2025-06-25 14:05:57 +10:00
psychedelicious
c0b32cb549 feat(ui): viewer integrates progress (wip) 2025-06-25 14:05:57 +10:00
psychedelicious
c3ffafeaeb feat(ui): switch to viewer/canvas on invoke 2025-06-25 14:05:57 +10:00
psychedelicious
4dc28758fc feat(ui): generation progress tab improvements 2025-06-25 14:05:57 +10:00
psychedelicious
bed7a03f14 feat(ui): show last progress message & placeholder in generation progress panel 2025-06-25 14:05:57 +10:00
psychedelicious
aa2fc1f3f9 fix(ui): staging area does not show placeholder on first render 2025-06-25 14:05:57 +10:00
psychedelicious
80f0ee77e2 feat(ui): double-click staging area image to disable auto-switch 2025-06-25 14:05:57 +10:00
psychedelicious
10cb511649 fix(ui): reset last started item id when doing autoswitch 2025-06-25 14:05:57 +10:00
psychedelicious
b751d92173 feat(ui): re-implement multiple auto-switch modes 2025-06-25 14:05:57 +10:00
psychedelicious
d213586202 chore: bump version to v6.0.0a4 2025-06-25 14:05:57 +10:00
psychedelicious
f1b7a6c9c5 feat(ui): no model error state for ref images 2025-06-25 14:05:57 +10:00
psychedelicious
c01f8f595f feat(ui): mini metadata viewer 2025-06-25 14:05:57 +10:00
psychedelicious
cf30d1d476 feat(ui): clean up image view components & code 2025-06-25 14:05:57 +10:00
psychedelicious
1bd9515b6c fix(ui): launchpad layouts 2025-06-25 14:05:56 +10:00
psychedelicious
aaafc1c55a fix(ui): don't use layers when generating on generate tab 2025-06-25 14:05:56 +10:00
psychedelicious
5e73597bb5 feat(ui): tweak vertical tab bar layout 2025-06-25 14:05:56 +10:00
psychedelicious
89657937cf fix(ui): unable to resize prompt box bc negative prompt button is over
the handle
2025-06-25 14:05:56 +10:00
psychedelicious
00cbba4dbd feat(ui): standardize auto layout structure 2025-06-25 14:05:56 +10:00
psychedelicious
6755cf2b43 feat(ui): tweak dockview tabs 2025-06-25 14:05:56 +10:00
psychedelicious
17ee4dd387 refactor(ui): rip out image viewer as modal 2025-06-25 14:05:56 +10:00
psychedelicious
bed8424dbd chore: bump version to v6.0.0a3 2025-06-25 14:05:56 +10:00
psychedelicious
712ae21994 chore(ui): lint 2025-06-25 14:05:56 +10:00
psychedelicious
85a8eae777 feat(ui): restore all panel hotkeys 2025-06-25 14:05:56 +10:00
psychedelicious
ec305ca10b fix(ui): generate tab hotkey 2025-06-25 14:05:56 +10:00
psychedelicious
6b8ff54839 feat(ui): restore floating panel buttons 2025-06-25 14:05:56 +10:00
psychedelicious
b88a258ba5 feat(ui): get all tabs working w/ new layout 2025-06-25 14:05:56 +10:00
psychedelicious
82a4e66070 fix(ui): unnecessary dependency on tab selection in
useCanvasDeleteLayerHotkey
2025-06-25 14:05:55 +10:00
psychedelicious
f6b85f8249 fix(ui): inverted logic for resume queue button 2025-06-25 14:05:55 +10:00
psychedelicious
ebd7087256 feat(ui): get layouts working 2025-06-25 14:05:55 +10:00
psychedelicious
9f39919c8f feat(ui): canvas launchpad 2025-06-25 14:05:55 +10:00
psychedelicious
4447e0d3ea wip 2025-06-25 14:05:55 +10:00
psychedelicious
ae50bedd88 fix(ui): wonky stage sizing on first visibility 2025-06-25 14:05:55 +10:00
psychedelicious
c23caa5116 wip 2025-06-25 14:05:55 +10:00
psychedelicious
6b5645adb6 feat(ui): port UI slice to zod 2025-06-25 14:05:55 +10:00
psychedelicious
9302bd5f11 fix(ui): only show weight for IP adapters 2025-06-25 14:05:55 +10:00
psychedelicious
740dac4602 feat(ui): represent IP adapter weight in ref image thumbnail 2025-06-25 14:05:55 +10:00
psychedelicious
7f49747a0f fix(ui): overflow on ref image model 2025-06-25 14:05:55 +10:00
psychedelicious
1426aade13 feat(ui): ref images feel more like buttons 2025-06-25 14:05:55 +10:00
psychedelicious
a6029ea60e feat(ui): switch tab on drag over tab button 2025-06-25 14:05:55 +10:00
psychedelicious
680c759af3 feat(ui): tweak splash screen layout 2025-06-25 14:05:54 +10:00
psychedelicious
1d2a85dd6e chore(ui): lint 2025-06-25 14:05:54 +10:00
psychedelicious
5695d2a3cc feat(ui): rework simple session initial state 2025-06-25 14:05:54 +10:00
psychedelicious
d89021f6a1 fix(ui): invoke button tooltip on generate tab 2025-06-25 14:05:54 +10:00
psychedelicious
e13276b052 fix(ui): progress image fixes 2025-06-25 14:05:54 +10:00
psychedelicious
bc1b1e187b feat(ui): make autoswitch on/off
When the invocation cache is used, we might skip all progress images. This can prevent auto-switch-on-first-progress from working, as we don't get any of those events.

It's much easier to only support auto-switch on complete.
2025-06-25 14:05:54 +10:00
psychedelicious
9da473cc51 feat(ui): refine ref images UI 2025-06-25 14:05:54 +10:00
psychedelicious
52ca4e0f19 feat(ui): toggleable negative prompt 2025-06-25 14:05:54 +10:00
psychedelicious
884cc6d47d fix(ui): remove old isSelected from refImageAdded call 2025-06-25 14:05:54 +10:00
psychedelicious
f5a5e5e7a8 chore: bump version to v6.0.0a2 2025-06-25 14:05:54 +10:00
psychedelicious
cd440eb836 fix(ui): update queue item preview images on init of queue items context 2025-06-25 14:05:54 +10:00
psychedelicious
2e4bd260c0 fix(ui): hack to close chakra tooltips on drag 2025-06-25 14:05:54 +10:00
psychedelicious
22d68652c9 tweak(ui): ref image header 2025-06-25 14:05:54 +10:00
psychedelicious
61bec9f30c experiment(ui): add generate tab 2025-06-25 14:05:53 +10:00
psychedelicious
d90abbe149 refactor(ui): ref images (WIP) 2025-06-25 14:05:53 +10:00
psychedelicious
4a5d873567 refactor(ui): ref images (WIP) 2025-06-25 14:05:53 +10:00
psychedelicious
626ca236d6 refactor(ui): refImage.ipAdapter -> refImage.config 2025-06-25 14:05:53 +10:00
psychedelicious
952483b3f3 feat(ui): split out ref images into own slice (WIP) 2025-06-25 14:05:53 +10:00
psychedelicious
16f50b0e8b feat(ui): simple session initial state cards are buttons 2025-06-25 14:05:53 +10:00
psychedelicious
1bbcd97b24 chore(ui): dpdm 2025-06-25 14:05:53 +10:00
psychedelicious
b4812f8c1d refactor(ui): async modal pattern; use for deleting images
This was needed for a canvas flow change which is currently paused, but the new API is much much nicer to use, so I am keeping it.
2025-06-25 14:05:53 +10:00
psychedelicious
187de81ae9 fix(ui): use imageDTO in staging area 2025-06-25 14:05:53 +10:00
psychedelicious
33b3b2994f fix(ui): wait until last queue item deleted before flagging canvas session finished 2025-06-25 14:05:53 +10:00
psychedelicious
8557684175 feat(ui): store output image DTO in session context instead of just the name 2025-06-25 14:05:53 +10:00
psychedelicious
84904f5ecb feat(ui): add AppGetState type 2025-06-25 14:05:53 +10:00
psychedelicious
80f645778c feat(ui): close viewer on escape 2025-06-25 14:05:53 +10:00
psychedelicious
bb423f79ef fix(ui): switch only on first progress image 2025-06-25 14:05:52 +10:00
psychedelicious
3956f52680 feat(ui): add on first progress autoswitch mode 2025-06-25 14:05:52 +10:00
psychedelicious
59b57f4cd0 feat(ui): move canvas-specific staging subscriptions to CanvasStagingAreaModule 2025-06-25 14:05:52 +10:00
psychedelicious
2fd5b4011b chore(ui): lint 2025-06-25 14:05:52 +10:00
psychedelicious
4a34e75794 feat(ui): make main panel styling and title consistent 2025-06-25 14:05:52 +10:00
psychedelicious
b46fd00a05 feat(ui): add startover button to canvas toolbar 2025-06-25 14:05:52 +10:00
psychedelicious
1474a0e8c1 feat(ui): fiddle w/ staging area header 2025-06-25 14:05:52 +10:00
psychedelicious
51c9beafa5 feat(ui): remove technical progress message from full preview 2025-06-25 14:05:52 +10:00
psychedelicious
554cc6b2fc feat(ui): simple session initial state 2025-06-25 14:05:52 +10:00
psychedelicious
6b3371942d feat(ui): remove vary and edit as control buttons 2025-06-25 14:05:52 +10:00
psychedelicious
dcba860771 refactor(ui): migrate from canceling queue items to deleteing, make queue hook APIs consistent 2025-06-25 14:05:52 +10:00
psychedelicious
862dd3f12f fix(ui): mini preview bg color 2025-06-25 14:05:52 +10:00
psychedelicious
3bbc2c458c fix(ui): hide layers when not on canvas tab 2025-06-25 14:05:51 +10:00
psychedelicious
f498610559 build(ui): temporarily ignore all knip issues 2025-06-25 14:05:51 +10:00
psychedelicious
2ce4b77821 feat(ui): finish generation when discarding last item 2025-06-25 14:05:51 +10:00
psychedelicious
d55aec6ec0 feat(ui): when discarding last item, select new last instead of first 2025-06-25 14:05:51 +10:00
psychedelicious
3017c14f1a feat(ui): tweak staging image display 2025-06-25 14:05:51 +10:00
psychedelicious
5a7dc1cb20 feat(ui): add staging area toolbar to simple session 2025-06-25 14:05:51 +10:00
psychedelicious
1c0d2ab6fb fix(ui): ensure canvas tool modules are destroyed 2025-06-25 14:05:51 +10:00
psychedelicious
f0fe070618 fix(ui): reset layers when changing session type 2025-06-25 14:05:51 +10:00
psychedelicious
d0decd5dca feat(ui): improved staging placeholders 2025-06-25 14:05:51 +10:00
psychedelicious
5f1f388114 feat(ui): improved staging placeholders 2025-06-25 14:05:51 +10:00
psychedelicious
977aa4f438 feat(ui): more staging fixes 2025-06-25 14:05:51 +10:00
psychedelicious
1c89152fe8 feat(ui): update canvas session state handling for new staging strat 2025-06-25 14:05:51 +10:00
psychedelicious
f6c39c7b34 chore(ui): lint (partial cleanup) 2025-06-25 14:05:51 +10:00
psychedelicious
ef8d993557 feat(ui): rough out canvas staging area 2025-06-25 14:05:50 +10:00
psychedelicious
0eaf313267 feat(app): support deleting queue items by id or destination 2025-06-25 14:05:50 +10:00
psychedelicious
cfa0ca27af feat(ui): tweak canvas scroll to zoom feel 2025-06-25 14:05:50 +10:00
psychedelicious
45a263bbab docs(ui): add comment about auto-switch not being quite right yet 2025-06-25 14:05:50 +10:00
psychedelicious
a9eebe47f3 feat: canvas flow rework (wip) 2025-06-25 14:05:50 +10:00
psychedelicious
8ffdd47644 feat(ui): prevent flicker of image action buttons 2025-06-25 14:05:50 +10:00
psychedelicious
6a568e5cf4 feat(ui): move socket events handling into ctx component 2025-06-25 14:05:50 +10:00
psychedelicious
2232dfa931 feat(ui): modularize all staging area logic so it can be shared w/ canvas more easily 2025-06-25 14:05:50 +10:00
psychedelicious
eea918803e perf(ui): queue actions menu is lazy 2025-06-25 14:05:50 +10:00
psychedelicious
bb7ae3a4fa fix(ui): cursor on staging area preview image 2025-06-25 14:05:50 +10:00
psychedelicious
c172c95b2b feat(ui): remove clear queue ui components 2025-06-25 14:05:50 +10:00
psychedelicious
a014f5b28c feat(app): do not prune queue on startup
With the new canvas design, this will result in loss of staging area images.
2025-06-25 14:05:50 +10:00
psychedelicious
3e7fb22916 tidy(ui): component organization 2025-06-25 14:05:49 +10:00
psychedelicious
55bc6b8c4d fix(ui): prevent drag of progress images 2025-06-25 14:05:49 +10:00
psychedelicious
93b1d0b678 feat: canvas flow rework (wip) 2025-06-25 14:05:49 +10:00
psychedelicious
a01a9a1bed feat: canvas flow rework (wip) 2025-06-25 14:05:49 +10:00
psychedelicious
e0a1f19bbb chore(ui): typegen 2025-06-25 14:05:49 +10:00
psychedelicious
4e886ca0e0 feat(api): remove status from list all queue items query 2025-06-25 14:05:49 +10:00
psychedelicious
a038c09bb3 tidy(ui): app layout components 2025-06-25 14:05:49 +10:00
psychedelicious
efc5531f4a feat: canvas flow rework (wip) 2025-06-25 14:05:49 +10:00
psychedelicious
19d34750af feat: canvas flow rework (wip) 2025-06-25 14:05:49 +10:00
psychedelicious
6931f4b538 feat: canvas flow rework (wip) 2025-06-25 14:05:49 +10:00
psychedelicious
2f63b617e4 fix(ui): unstable selector results in lora drop down 2025-06-25 14:05:49 +10:00
psychedelicious
bfac64430f feat: canvas flow rework (wip) 2025-06-25 14:05:49 +10:00
psychedelicious
5085fdabb2 feat: canvas flow rework (wip) 2025-06-25 14:05:48 +10:00
psychedelicious
bcc3e3b338 wip progress events 2025-06-25 14:05:48 +10:00
psychedelicious
db8fbc33c2 refactor(ui): canvas flow (wip) 2025-06-25 14:05:48 +10:00
psychedelicious
862343990d fix(ui): ref goes undefined in GalleryImage
This appears to be a bug in Chakra UI v2 - use of a fallback component makes the ref passed to an image end up undefined. Had to remove the skeleton loader fallback component.
2025-06-25 14:05:48 +10:00
psychedelicious
93338bd875 fix(ui): merge refs when forwardingin DndImage 2025-06-25 14:05:48 +10:00
psychedelicious
6f1192c5ae fix(ui): remove unused sessionId field from type 2025-06-25 14:05:48 +10:00
psychedelicious
7d9af08bc8 fix(ui): ensure all args are passed to handler when creating new canvas from image 2025-06-25 14:05:48 +10:00
psychedelicious
be44e7449e feat(ui): bookmark new inpaint masks 2025-06-25 14:05:48 +10:00
psychedelicious
a015789521 feat(ui): support bookmarking an entity when adding it 2025-06-25 14:05:48 +10:00
psychedelicious
176624accd fix(ui): ensure images are added to gallery in simple sessions 2025-06-25 14:05:48 +10:00
psychedelicious
7e408abc85 feat(ui): images always added to gallery in simple session 2025-06-25 14:05:48 +10:00
psychedelicious
84de2062dc wip 2025-06-25 14:05:48 +10:00
psychedelicious
79c106fcd3 refactor(ui): canvas flow (wip) 2025-06-25 14:05:48 +10:00
psychedelicious
0c289fb892 refactor(ui): canvas flow (wip) 2025-06-25 14:05:47 +10:00
psychedelicious
36d78a335a refactor(ui): canvas flow events (wip) 2025-06-25 14:05:47 +10:00
psychedelicious
9cd5bc29e7 refactor(ui): canvas flow (wip) 2025-06-25 14:05:47 +10:00
psychedelicious
ffc3a222e0 refactor(ui): canvas flow (wip) 2025-06-25 14:05:47 +10:00
psychedelicious
733a010996 refactor(ui): canvas flow (wip) 2025-06-25 14:05:47 +10:00
psychedelicious
c3950af063 refactor(ui): canvas flow (wip) 2025-06-25 14:05:47 +10:00
psychedelicious
371d6c73f6 fix(ui): circular import issue 2025-06-25 14:05:47 +10:00
psychedelicious
6c330cae83 refactor(ui): params state zodification 2025-06-25 14:05:47 +10:00
psychedelicious
50fc096cf1 refactor(ui): move params state to big file of canvas zod stuff 2025-06-25 14:05:47 +10:00
psychedelicious
9960eee7f6 refactor(ui): zod-ify params slice state 2025-06-25 14:05:47 +10:00
psychedelicious
7604ac54d0 refactor(ui): org state in prep for new flow 2025-06-25 14:05:47 +10:00
psychedelicious
c66cf04e06 refactor(ui): image viewer & comparison convolutedness 2025-06-25 14:05:47 +10:00
psychedelicious
77f6824ddd feat(ui): default canvas tool is move 2025-06-25 14:05:47 +10:00
psychedelicious
8f4cd7544b chore(ui): bump @reduxjs/toolkit to latest 2025-06-25 14:05:46 +10:00
psychedelicious
6e135d5f05 feat(ui): viewer is a modal (wip) 2025-06-25 14:05:46 +10:00
555 changed files with 15781 additions and 23846 deletions

View File

@@ -3,15 +3,15 @@ description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: setup node 20
- name: setup node 18
uses: actions/setup-node@v4
with:
node-version: '20'
node-version: '18'
- name: setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10
version: 8.15.6
run_install: false
- name: get pnpm store directory

View File

@@ -297,7 +297,7 @@ Migration logic is in [migrations.ts].
<!-- links -->
[pydantic]: https://github.com/pydantic/pydantic 'pydantic'
[zod]: https://github.com/colinhacks/zod 'zod/v4'
[zod]: https://github.com/colinhacks/zod 'zod'
[openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types'
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions

View File

@@ -35,7 +35,7 @@ More detail on system requirements can be found [here](./requirements.md).
## Step 2: Download
Download the most recent launcher for your operating system:
Download the most launcher for your operating system:
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)

View File

@@ -14,7 +14,6 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecordChanges,
ResourceOrigin,
)
@@ -72,7 +71,7 @@ async def upload_image(
resize_to: Optional[str] = Body(
default=None,
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
examples=['"[1024,1024]"'],
example='"[1024,1024]"',
),
metadata: Optional[str] = Body(
default=None,
@@ -577,11 +576,11 @@ async def get_image_names(
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates"""
) -> list[str]:
"""Gets ordered list of all image names (starred first, then unstarred)"""
try:
result = ApiDependencies.invoker.services.images.get_image_names(
image_names = ApiDependencies.invoker.services.images.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
@@ -590,34 +589,6 @@ async def get_image_names(
board_id=board_id,
search_term=search_term,
)
return result
return image_names
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image names")
@images_router.post(
"/images_by_names",
operation_id="get_images_by_names",
responses={200: {"model": list[ImageDTO]}},
)
async def get_images_by_names(
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
) -> list[ImageDTO]:
"""Gets image DTOs for the specified image names. Maintains order of input names."""
try:
image_service = ApiDependencies.invoker.services.images
# Fetch DTOs preserving the order of requested names
image_dtos: list[ImageDTO] = []
for name in image_names:
try:
dto = image_service.get_dto(name)
image_dtos.append(dto)
except Exception:
# Skip missing images - they may have been deleted between name fetch and DTO fetch
continue
return image_dtos
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image DTOs")

View File

@@ -41,7 +41,6 @@ from invokeai.backend.model_manager.starter_models import (
STARTER_BUNDLES,
STARTER_MODELS,
StarterModel,
StarterModelBundle,
StarterModelWithoutDependencies,
)
@@ -292,7 +291,7 @@ async def get_hugging_face_models(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
@@ -450,7 +449,7 @@ async def install_model(
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
examples=[{"name": "string", "description": "string"}],
example={"name": "string", "description": "string"},
),
) -> ModelInstallJob:
"""Install a model using a string identifier.
@@ -800,7 +799,7 @@ async def convert_model(
class StarterModelResponse(BaseModel):
starter_models: list[StarterModel]
starter_bundles: dict[str, StarterModelBundle]
starter_bundles: dict[str, list[StarterModel]]
def get_is_installed(
@@ -834,7 +833,7 @@ async def get_starter_models() -> StarterModelResponse:
model.dependencies = missing_deps
for bundle in starter_bundles.values():
for model in bundle.models:
for model in bundle:
model.is_installed = get_is_installed(model, installed_models)
# Remove already-installed dependencies
missing_deps: list[StarterModelWithoutDependencies] = []

View File

@@ -1,6 +1,6 @@
from typing import Optional
from fastapi import Body, HTTPException, Path, Query
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
@@ -22,7 +22,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemNotFoundError,
SessionQueueStatus,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -60,12 +59,10 @@ async def enqueue_batch(
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
try:
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
@session_queue_router.get(
@@ -85,17 +82,14 @@ async def list_queue_items(
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
@session_queue_router.get(
@@ -110,13 +104,11 @@ async def list_all_queue_items(
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
@session_queue_router.put(
@@ -128,10 +120,7 @@ async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Resumes session processor"""
try:
return ApiDependencies.invoker.services.session_processor.resume()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while resuming queue: {e}")
return ApiDependencies.invoker.services.session_processor.resume()
@session_queue_router.put(
@@ -143,10 +132,7 @@ async def Pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Pauses session processor"""
try:
return ApiDependencies.invoker.services.session_processor.pause()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pausing queue: {e}")
return ApiDependencies.invoker.services.session_processor.pause()
@session_queue_router.put(
@@ -158,10 +144,7 @@ async def cancel_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> CancelAllExceptCurrentResult:
"""Immediately cancels all queue items except in-processing items"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
@session_queue_router.put(
@@ -173,10 +156,7 @@ async def delete_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> DeleteAllExceptCurrentResult:
"""Immediately deletes all queue items except in-processing items"""
try:
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
@session_queue_router.put(
@@ -189,12 +169,7 @@ async def cancel_by_batch_ids(
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
queue_id=queue_id, batch_ids=batch_ids
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
@@ -207,12 +182,9 @@ async def cancel_by_destination(
destination: str = Query(description="The destination to cancel all queue items for"),
) -> CancelByDestinationResult:
"""Immediately cancels all queue items with the given origin"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
@session_queue_router.put(
@@ -225,10 +197,7 @@ async def retry_items_by_id(
item_ids: list[int] = Body(description="The queue item ids to retry"),
) -> RetryItemsResult:
"""Immediately cancels all queue items with the given origin"""
try:
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
@session_queue_router.put(
@@ -242,14 +211,11 @@ async def clear(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
@session_queue_router.put(
@@ -263,10 +229,7 @@ async def prune(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
try:
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
@session_queue_router.get(
@@ -280,10 +243,7 @@ async def get_current_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
@session_queue_router.get(
@@ -297,10 +257,7 @@ async def get_next_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
try:
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
@session_queue_router.get(
@@ -314,12 +271,9 @@ async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting queue status: {e}")
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
@session_queue_router.get(
@@ -334,10 +288,7 @@ async def get_batch_status(
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of the session queue"""
try:
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
@session_queue_router.get(
@@ -353,12 +304,7 @@ async def get_queue_item(
item_id: int = Path(description="The queue item to get"),
) -> SessionQueueItem:
"""Gets a queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
@session_queue_router.delete(
@@ -370,10 +316,7 @@ async def delete_queue_item(
item_id: int = Path(description="The queue item to delete"),
) -> None:
"""Deletes a queue item"""
try:
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
@session_queue_router.put(
@@ -388,12 +331,8 @@ async def cancel_queue_item(
item_id: int = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Deletes a queue item"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
@session_queue_router.get(
@@ -406,12 +345,9 @@ async def counts_by_destination(
destination: str = Query(description="The destination to query"),
) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination"""
try:
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
@session_queue_router.delete(
@@ -424,9 +360,6 @@ async def delete_by_destination(
destination: str = Path(description="The destination to query"),
) -> DeleteByDestinationResult:
"""Deletes all items with the given destination"""
try:
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)

View File

@@ -64,7 +64,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
FluxKontextModel = "FluxKontextModelField"
# endregion
# region Misc Field Types
@@ -215,7 +214,6 @@ class FieldDescriptions:
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"
flux_fill_conditioning = "FLUX Fill conditioning tensor"
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"
class ImageField(BaseModel):
@@ -292,12 +290,6 @@ class FluxFillConditioningField(BaseModel):
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
class FluxKontextConditioningField(BaseModel):
"""A conditioning field for FLUX Kontext (reference image)."""
image: ImageField = Field(description="The Kontext reference image.")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -16,12 +16,13 @@ from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
FluxFillConditioningField,
FluxKontextConditioningField,
FluxReduxConditioningField,
ImageField,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
@@ -33,7 +34,6 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXCo
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
@@ -63,9 +63,9 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="4.0.0",
version="3.3.0",
)
class FluxDenoiseInvocation(BaseInvocation):
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a FLUX transformer model."""
# If latents is provided, this means we are doing image-to-image.
@@ -145,20 +145,11 @@ class FluxDenoiseInvocation(BaseInvocation):
description=FieldDescriptions.vae,
input=Input.Connection,
)
# This node accepts a images for features like FLUX Fill, ControlNet, and Kontext, but needs to operate on them in
# latent space. We'll run the VAE to encode them in this node instead of requiring the user to run the VAE in
# upstream nodes.
ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
)
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
default=None,
description="FLUX Kontext conditioning (reference image).",
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
@@ -385,27 +376,6 @@ class FluxDenoiseInvocation(BaseInvocation):
dtype=inference_dtype,
)
kontext_extension = None
if self.kontext_conditioning is not None:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning,
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
# Prepare Kontext conditioning if provided
img_cond_seq = None
img_cond_seq_ids = None
if kontext_extension is not None:
# Ensure batch sizes match
kontext_extension.ensure_batch_size(x.shape[0])
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
x = denoise(
model=transformer,
img=x,
@@ -421,8 +391,6 @@ class FluxDenoiseInvocation(BaseInvocation):
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
img_cond_seq=img_cond_seq,
img_cond_seq_ids=img_cond_seq_ids,
)
x = unpack(x.float(), self.height, self.width)
@@ -897,10 +865,7 @@ class FluxDenoiseInvocation(BaseInvocation):
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
# The denoise function now handles Kontext conditioning correctly,
# so we don't need to slice the latents here
latents = state.latents.float()
state.latents = unpack(latents, self.height, self.width).squeeze()
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
context.util.flux_step_callback(state)
return step_callback

View File

@@ -1,40 +0,0 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxKontextConditioningField,
InputField,
OutputField,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_kontext_output")
class FluxKontextOutput(BaseInvocationOutput):
"""The conditioning output of a FLUX Kontext invocation."""
kontext_cond: FluxKontextConditioningField = OutputField(
description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
)
@invocation(
"flux_kontext",
title="Kontext Conditioning - FLUX",
tags=["conditioning", "kontext", "flux"],
category="conditioning",
version="1.0.0",
)
class FluxKontextInvocation(BaseInvocation):
"""Prepares a reference image for FLUX Kontext conditioning."""
image: ImageField = InputField(description="The Kontext reference image.")
def invoke(self, context: InvocationContext) -> FluxKontextOutput:
"""Packages the provided image into a Kontext conditioning field."""
return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))

View File

@@ -1,5 +1,5 @@
from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple, Union
from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
@@ -111,9 +111,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
if context.config.get().log_tokenization:
self._log_t5_tokenization(context, t5_tokenizer)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
@@ -154,9 +151,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
if context.config.get().log_tokenization:
self._log_clip_tokenization(context, clip_tokenizer)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
@@ -176,88 +170,3 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _log_t5_tokenization(
self,
context: InvocationContext,
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
) -> None:
"""Logs the tokenization of a prompt for a T5-based model like FLUX."""
# Tokenize the prompt using the same parameters as the model's text encoder.
# T5 tokenizers add an EOS token (</s>) and then pad to max_length.
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=self.t5_max_seq_len,
truncation=True,
add_special_tokens=True, # This is important for T5 to add the EOS token.
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The T5 tokenizer uses a space-like character ' ' (U+2581) to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("\u2581", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for token in tokens:
if token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
used_tokens += 1
elif token == tokenizer.pad_token:
# tokenized_str += f"\x1b[0;34m{token}\x1b[0m" # Blue for PAD
continue
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
context.logger.info(f"{tokenized_str}\x1b[0m")
def _log_clip_tokenization(
self,
context: InvocationContext,
tokenizer: CLIPTokenizer,
) -> None:
"""Logs the tokenization of a prompt for a CLIP-based model."""
max_length = tokenizer.model_max_length
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
attention_mask = tokenized_output.attention_mask[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The CLIP tokenizer uses '</w>' to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("</w>", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for i, token in enumerate(tokens):
if attention_mask[i] == 0:
# Do not log padding tokens.
continue
if token == tokenizer.bos_token:
tokenized_str += f"\x1b[0;32m{token}\x1b[0m" # Green for BOS
elif token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
context.logger.info(f"{tokenized_str}\x1b[0m")

View File

@@ -5,7 +5,6 @@ from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -109,6 +108,6 @@ class ImageRecordStorageBase(ABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
) -> list[str]:
"""Gets ordered list of all image names (starred first, then unstarred)."""
pass

View File

@@ -212,11 +212,3 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
class ImageCollectionCounts(BaseModel):
starred_count: int = Field(description="The number of starred images in the collection.")
unstarred_count: int = Field(description="The number of unstarred images in the collection.")
class ImageNamesResult(BaseModel):
"""Response containing ordered image names with metadata for optimistic updates."""
image_names: list[str] = Field(description="Ordered list of image names")
starred_count: int = Field(description="Number of starred images (when starred_first=True)")
total_count: int = Field(description="Total number of images matching the query")

View File

@@ -7,7 +7,6 @@ from invokeai.app.services.image_records.image_records_base import ImageRecordSt
from invokeai.app.services.image_records.image_records_common import (
IMAGE_DTO_COLS,
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -397,10 +396,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
) -> list[str]:
cursor = self._conn.cursor()
# Build query conditions (reused for both starred count and image names queries)
# Base query to get image names in order (starred first, then unstarred)
query = """--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
@@ -445,38 +451,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
query += (
query_conditions
+ f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
)
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
query += (
query_conditions
+ f"""--sql
ORDER BY images.created_at {order_dir.value}
"""
)
cursor.execute(names_query, query_params)
cursor.execute(query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [row[0] for row in result]
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
return [row[0] for row in result]

View File

@@ -6,7 +6,6 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -159,6 +158,6 @@ class ImageServiceABC(ABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
) -> list[str]:
"""Gets ordered list of all image names."""
pass

View File

@@ -10,7 +10,6 @@ from invokeai.app.services.image_files.image_files_common import (
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -320,7 +319,7 @@ class ImageService(ImageServiceABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
) -> list[str]:
try:
return self.__invoker.services.image_records.get_image_names(
starred_first=starred_first,

View File

@@ -205,7 +205,6 @@ class FieldIdentifier(BaseModel):
kind: Literal["input", "output"] = Field(description="The kind of field")
node_id: str = Field(description="The ID of the node")
field_name: str = Field(description="The name of the field")
user_label: str | None = Field(description="The user label of the field, if any")
class SessionQueueItem(BaseModel):
@@ -332,7 +331,6 @@ class EnqueueBatchResult(BaseModel):
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
item_ids: list[int] = Field(description="The IDs of the queue items that were enqueued")
class RetryItemsResult(BaseModel):

View File

@@ -133,18 +133,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
values_to_insert,
)
with self._conn:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT item_id
FROM session_queue
WHERE batch_id = ?
ORDER BY item_id DESC;
""",
(batch.batch_id,),
)
item_ids = [row[0] for row in cursor.fetchall()]
except Exception:
raise
enqueue_result = EnqueueBatchResult(
@@ -153,7 +141,6 @@ class SqliteSessionQueue(SessionQueueBase):
enqueued=enqueued_count,
batch=batch,
priority=priority,
item_ids=item_ids,
)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
return enqueue_result
@@ -404,8 +391,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id] + batch_ids
cursor.execute(
@@ -444,8 +429,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = (queue_id, destination)
cursor.execute(
@@ -548,8 +531,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id]
cursor.execute(
@@ -570,9 +551,12 @@ class SqliteSessionQueue(SessionQueueBase):
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
except Exception:
self._conn.rollback()
raise
@@ -743,7 +727,7 @@ class SqliteSessionQueue(SessionQueueBase):
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueStatus(
queue_id=queue_id,
@@ -772,7 +756,7 @@ class SqliteSessionQueue(SessionQueueBase):
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
@@ -804,7 +788,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in counts_result)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueCountsByDestination(

View File

@@ -2,7 +2,7 @@
import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import (
@@ -58,32 +58,17 @@ class Edge(BaseModel):
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want risk breaking anything.
try:
invocation_class = type(node)
invocation_output_class = invocation_class.get_output_annotation()
field_info = invocation_output_class.model_fields.get(field)
assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}"
output_field_type = field_info.annotation
return output_field_type
except Exception:
return None
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want risk breaking anything.
try:
invocation_class = type(node)
field_info = invocation_class.model_fields.get(field)
assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}"
input_field_type = field_info.annotation
return input_field_type
except Exception:
return None
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
return node_input_field
def is_union_subtype(t1, t2):
@@ -1007,11 +992,10 @@ class GraphExecutionState(BaseModel):
new_node_ids = []
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = []
for source_node_id in next_node_parents:
prepared_nodes = self.source_prepared_mapping[source_node_id]
all_iteration_mappings.extend([(source_node_id, p) for p in prepared_nodes])
all_iteration_mappings = list(
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
)
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
if create_results is not None:
new_node_ids.extend(create_results)

View File

@@ -123,11 +123,7 @@ def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
if total_steps == 0:
return 0.0
if order == 2:
# Prevent division by zero when total_steps is 1 or 2
denominator = floor(total_steps / 2)
if denominator == 0:
return 0.0
return floor(step / 2) / denominator
return floor(step / 2) / floor(total_steps / 2)
# order == 1
return step / total_steps

View File

@@ -30,11 +30,8 @@ def denoise(
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
# extra img tokens (channel-wise)
# extra img tokens
img_cond: torch.Tensor | None,
# extra img tokens (sequence-wise) - for Kontext conditioning
img_cond_seq: torch.Tensor | None = None,
img_cond_seq_ids: torch.Tensor | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -49,10 +46,6 @@ def denoise(
)
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# Store original sequence length for slicing predictions
original_seq_len = img.shape[1]
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -78,26 +71,10 @@ def denoise(
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
# Prepare input for model - concatenate fresh each step
img_input = img
img_input_ids = img_ids
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
if img_cond is not None:
img_input = torch.cat((img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (for Kontext)
if img_cond_seq is not None:
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
pred = model(
img=img_input,
img_ids=img_input_ids,
img=pred_img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
@@ -111,10 +88,6 @@ def denoise(
regional_prompting_extension=pos_regional_prompting_extension,
)
# Slice prediction to only include the main image tokens
if img_input_ids is not None:
pred = pred[:, :original_seq_len]
step_cfg_scale = cfg_scale[step_index]
# If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.

View File

@@ -1,149 +0,0 @@
import einops
import numpy as np
import torch
from einops import repeat
from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
def generate_img_ids_with_offset(
latent_height: int,
latent_width: int,
batch_size: int,
device: torch.device,
dtype: torch.dtype,
idx_offset: int = 0,
) -> torch.Tensor:
"""Generate tensor of image position ids with an optional offset.
Args:
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
latent_width (int): Width of image in latent space (after packing, this becomes w//2).
batch_size (int): Number of images in the batch.
device (torch.device): Device to create tensors on.
dtype (torch.dtype): Data type for the tensors.
idx_offset (int): Offset to add to the first dimension of the image ids.
Returns:
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
"""
if device.type == "mps":
orig_dtype = dtype
dtype = torch.float16
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
packed_height = latent_height // 2
packed_width = latent_width // 2
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
# The 3 channels represent: [batch_offset, y_position, x_position]
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
# Set the batch offset for all positions
img_ids[..., 0] = idx_offset
# Create y-coordinate indices (vertical positions)
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
img_ids[..., 1] = y_indices[:, None]
# Create x-coordinate indices (horizontal positions)
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
img_ids[..., 2] = x_indices[None, :]
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
if device.type == "mps":
img_ids = img_ids.to(orig_dtype)
return img_ids
class KontextExtension:
"""Applies FLUX Kontext (reference image) conditioning."""
def __init__(
self,
kontext_conditioning: FluxKontextConditioningField,
context: InvocationContext,
vae_field: VAEField,
device: torch.device,
dtype: torch.dtype,
):
"""
Initializes the KontextExtension, pre-processing the reference image
into latents and positional IDs.
"""
self._context = context
self._device = device
self._dtype = dtype
self._vae_field = vae_field
self.kontext_conditioning = kontext_conditioning
# Pre-process and cache the kontext latents and ids upon initialization.
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encodes the reference image and prepares its latents and IDs."""
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
# Calculate aspect ratio of input image
width, height = image.size
aspect_ratio = width / height
# Find the closest preferred resolution by aspect ratio
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
# This ensures compatibility with the model's training
scaled_width = 2 * int(target_width / 16)
scaled_height = 2 * int(target_height / 16)
# Resize to the exact resolution used during training
image = image.convert("RGB")
final_width = 8 * scaled_width
final_height = 8 * scaled_height
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
# Convert to tensor with same normalization as BFL
image_np = np.array(image)
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
vae_info = self._context.models.load(self._vae_field.vae)
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pack the latents and generate IDs
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1,
)
return kontext_latents_packed, kontext_ids
def ensure_batch_size(self, target_batch_size: int) -> None:
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""
if self.kontext_latents.shape[0] != target_batch_size:
self.kontext_latents = self.kontext_latents.repeat(target_batch_size, 1, 1)
self.kontext_ids = self.kontext_ids.repeat(target_batch_size, 1, 1)

View File

@@ -174,13 +174,11 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
dtype = torch.float16
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
# Set batch offset to 0 for main image tokens
img_ids[..., 0] = 0
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
if device.type == "mps":
img_ids = img_ids.to(orig_dtype)
img_ids.to(orig_dtype)
return img_ids

View File

@@ -18,29 +18,6 @@ class ModelSpec:
repo_ae: str | None
# Preferred resolutions for Kontext models to avoid tiling artifacts
# These are the specific resolutions the model was trained on
PREFERED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-dev-fill": 512,

View File

@@ -37,7 +37,6 @@ from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
@@ -353,14 +352,15 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
metadata = mod.metadata()
architecture = metadata["modelspec.architecture"]
base_str, _ = metadata["modelspec.architecture"].split("/")
base_str = base_str.lower()
if architecture == stable_diffusion_xl_1_lora:
if "stable-diffusion-xl-v1-base" in base_str:
base = BaseModelType.StableDiffusionXL
elif architecture == flux_dev_1_lora:
elif "flux" in base_str:
base = BaseModelType.Flux
else:
raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}")
raise InvalidModelConfigException(f"Unrecognised/unsupported base architecture for OMI LoRA: {base_str}")
return {"base": base}

View File

@@ -7,14 +7,7 @@ from typing import Optional
import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
AutoConfig,
AutoModelForTextEncoding,
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -146,7 +139,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
)
match submodel_type:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
@@ -190,7 +183,7 @@ class T5EncoderCheckpointModel(ModelLoader):
match submodel_type:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(
Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True

View File

@@ -13,7 +13,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.omi.omi import convert_from_omi
from invokeai.backend.model_manager.omi import convert_from_omi
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,

View File

@@ -0,0 +1,16 @@
import omi_model_standards.convert.lora.convert_lora_util as lora_util
from omi_model_standards.convert.lora.convert_flux_lora import convert_flux_lora_key_sets
from omi_model_standards.convert.lora.convert_sdxl_lora import convert_sdxl_lora_key_sets
from invokeai.backend.model_manager.model_on_disk import StateDict
from invokeai.backend.model_manager.taxonomy import BaseModelType
def convert_from_omi(weights_sd: StateDict, base: BaseModelType):
keyset = {
BaseModelType.Flux: convert_flux_lora_key_sets(),
BaseModelType.StableDiffusionXL: convert_sdxl_lora_key_sets(),
}[base]
source = "omi"
target = "legacy_diffusers"
return lora_util.__convert(weights_sd, keyset, source, target) # type: ignore

View File

@@ -1,7 +0,0 @@
from invokeai.backend.model_manager.omi.omi import convert_from_omi
from invokeai.backend.model_manager.omi.vendor.model_spec.architecture import (
flux_dev_1_lora,
stable_diffusion_xl_1_lora,
)
__all__ = ["flux_dev_1_lora", "stable_diffusion_xl_1_lora", "convert_from_omi"]

View File

@@ -1,21 +0,0 @@
from invokeai.backend.model_manager.model_on_disk import StateDict
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
convert_flux_lora as omi_flux,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
convert_lora_util as lora_util,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
convert_sdxl_lora as omi_sdxl,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType
def convert_from_omi(weights_sd: StateDict, base: BaseModelType):
keyset = {
BaseModelType.Flux: omi_flux.convert_flux_lora_key_sets(),
BaseModelType.StableDiffusionXL: omi_sdxl.convert_sdxl_lora_key_sets(),
}[base]
source = "omi"
target = "legacy_diffusers"
return lora_util.__convert(weights_sd, keyset, source, target) # type: ignore

View File

@@ -1,20 +0,0 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
def map_clip(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("text_projection", "text_projection", parent=key_prefix)]
for k in map_prefix_range("text_model.encoder.layers", "text_model.encoder.layers", parent=key_prefix):
keys += [LoraConversionKeySet("mlp.fc1", "mlp.fc1", parent=k)]
keys += [LoraConversionKeySet("mlp.fc2", "mlp.fc2", parent=k)]
keys += [LoraConversionKeySet("self_attn.k_proj", "self_attn.k_proj", parent=k)]
keys += [LoraConversionKeySet("self_attn.out_proj", "self_attn.out_proj", parent=k)]
keys += [LoraConversionKeySet("self_attn.q_proj", "self_attn.q_proj", parent=k)]
keys += [LoraConversionKeySet("self_attn.v_proj", "self_attn.v_proj", parent=k)]
return keys

View File

@@ -1,84 +0,0 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_t5 import map_t5
def __map_double_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("img_attn.qkv.0", "attn.to_q", parent=key_prefix)]
keys += [LoraConversionKeySet("img_attn.qkv.1", "attn.to_k", parent=key_prefix)]
keys += [LoraConversionKeySet("img_attn.qkv.2", "attn.to_v", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_attn.qkv.0", "attn.add_q_proj", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_attn.qkv.1", "attn.add_k_proj", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_attn.qkv.2", "attn.add_v_proj", parent=key_prefix)]
keys += [LoraConversionKeySet("img_attn.proj", "attn.to_out.0", parent=key_prefix)]
keys += [LoraConversionKeySet("img_mlp.0", "ff.net.0.proj", parent=key_prefix)]
keys += [LoraConversionKeySet("img_mlp.2", "ff.net.2", parent=key_prefix)]
keys += [LoraConversionKeySet("img_mod.lin", "norm1.linear", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_attn.proj", "attn.to_add_out", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_mlp.0", "ff_context.net.0.proj", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_mlp.2", "ff_context.net.2", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_mod.lin", "norm1_context.linear", parent=key_prefix)]
return keys
def __map_single_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("linear1.0", "attn.to_q", parent=key_prefix)]
keys += [LoraConversionKeySet("linear1.1", "attn.to_k", parent=key_prefix)]
keys += [LoraConversionKeySet("linear1.2", "attn.to_v", parent=key_prefix)]
keys += [LoraConversionKeySet("linear1.3", "proj_mlp", parent=key_prefix)]
keys += [LoraConversionKeySet("linear2", "proj_out", parent=key_prefix)]
keys += [LoraConversionKeySet("modulation.lin", "norm.linear", parent=key_prefix)]
return keys
def __map_transformer(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("txt_in", "context_embedder", parent=key_prefix)]
keys += [
LoraConversionKeySet("final_layer.adaLN_modulation.1", "norm_out.linear", parent=key_prefix, swap_chunks=True)
]
keys += [LoraConversionKeySet("final_layer.linear", "proj_out", parent=key_prefix)]
keys += [
LoraConversionKeySet("guidance_in.in_layer", "time_text_embed.guidance_embedder.linear_1", parent=key_prefix)
]
keys += [
LoraConversionKeySet("guidance_in.out_layer", "time_text_embed.guidance_embedder.linear_2", parent=key_prefix)
]
keys += [LoraConversionKeySet("vector_in.in_layer", "time_text_embed.text_embedder.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("vector_in.out_layer", "time_text_embed.text_embedder.linear_2", parent=key_prefix)]
keys += [LoraConversionKeySet("time_in.in_layer", "time_text_embed.timestep_embedder.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("time_in.out_layer", "time_text_embed.timestep_embedder.linear_2", parent=key_prefix)]
keys += [LoraConversionKeySet("img_in.proj", "x_embedder", parent=key_prefix)]
for k in map_prefix_range("double_blocks", "transformer_blocks", parent=key_prefix):
keys += __map_double_transformer_block(k)
for k in map_prefix_range("single_blocks", "single_transformer_blocks", parent=key_prefix):
keys += __map_single_transformer_block(k)
return keys
def convert_flux_lora_key_sets() -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
keys += __map_transformer(LoraConversionKeySet("transformer", "lora_transformer"))
keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
keys += map_t5(LoraConversionKeySet("t5", "lora_te2"))
return keys

View File

@@ -1,217 +0,0 @@
import torch
from torch import Tensor
from typing_extensions import Self
class LoraConversionKeySet:
def __init__(
self,
omi_prefix: str,
diffusers_prefix: str,
legacy_diffusers_prefix: str | None = None,
parent: Self | None = None,
swap_chunks: bool = False,
filter_is_last: bool | None = None,
next_omi_prefix: str | None = None,
next_diffusers_prefix: str | None = None,
):
if parent is not None:
self.omi_prefix = combine(parent.omi_prefix, omi_prefix)
self.diffusers_prefix = combine(parent.diffusers_prefix, diffusers_prefix)
else:
self.omi_prefix = omi_prefix
self.diffusers_prefix = diffusers_prefix
if legacy_diffusers_prefix is None:
self.legacy_diffusers_prefix = self.diffusers_prefix.replace(".", "_")
elif parent is not None:
self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace(
".", "_"
)
else:
self.legacy_diffusers_prefix = legacy_diffusers_prefix
self.parent = parent
self.swap_chunks = swap_chunks
self.filter_is_last = filter_is_last
self.prefix = parent
if next_omi_prefix is None and parent is not None:
self.next_omi_prefix = parent.next_omi_prefix
self.next_diffusers_prefix = parent.next_diffusers_prefix
self.next_legacy_diffusers_prefix = parent.next_legacy_diffusers_prefix
elif next_omi_prefix is not None and parent is not None:
self.next_omi_prefix = combine(parent.omi_prefix, next_omi_prefix)
self.next_diffusers_prefix = combine(parent.diffusers_prefix, next_diffusers_prefix)
self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace(
".", "_"
)
elif next_omi_prefix is not None and parent is None:
self.next_omi_prefix = next_omi_prefix
self.next_diffusers_prefix = next_diffusers_prefix
self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace(".", "_")
else:
self.next_omi_prefix = None
self.next_diffusers_prefix = None
self.next_legacy_diffusers_prefix = None
def __get_omi(self, in_prefix: str, key: str) -> str:
return self.omi_prefix + key.removeprefix(in_prefix)
def __get_diffusers(self, in_prefix: str, key: str) -> str:
return self.diffusers_prefix + key.removeprefix(in_prefix)
def __get_legacy_diffusers(self, in_prefix: str, key: str) -> str:
key = self.legacy_diffusers_prefix + key.removeprefix(in_prefix)
suffix = key[key.rfind(".") :]
if suffix not in [".alpha", ".dora_scale"]: # some keys only have a single . in the suffix
suffix = key[key.removesuffix(suffix).rfind(".") :]
key = key.removesuffix(suffix)
return key.replace(".", "_") + suffix
def get_key(self, in_prefix: str, key: str, target: str) -> str:
if target == "omi":
return self.__get_omi(in_prefix, key)
elif target == "diffusers":
return self.__get_diffusers(in_prefix, key)
elif target == "legacy_diffusers":
return self.__get_legacy_diffusers(in_prefix, key)
return key
def __str__(self) -> str:
return f"omi: {self.omi_prefix}, diffusers: {self.diffusers_prefix}, legacy: {self.legacy_diffusers_prefix}"
def combine(left: str, right: str) -> str:
left = left.rstrip(".")
right = right.lstrip(".")
if left == "" or left is None:
return right
elif right == "" or right is None:
return left
else:
return left + "." + right
def map_prefix_range(
omi_prefix: str,
diffusers_prefix: str,
parent: LoraConversionKeySet,
) -> list[LoraConversionKeySet]:
# 100 should be a safe upper bound. increase if it's not enough in the future
return [
LoraConversionKeySet(
omi_prefix=f"{omi_prefix}.{i}",
diffusers_prefix=f"{diffusers_prefix}.{i}",
parent=parent,
next_omi_prefix=f"{omi_prefix}.{i + 1}",
next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
)
for i in range(100)
]
def __convert(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
source: str,
target: str,
) -> dict[str, Tensor]:
out_states = {}
if source == target:
return dict(state_dict)
# TODO: maybe replace with a non O(n^2) algorithm
for key, tensor in state_dict.items():
for key_set in key_sets:
in_prefix = ""
if source == "omi":
in_prefix = key_set.omi_prefix
elif source == "diffusers":
in_prefix = key_set.diffusers_prefix
elif source == "legacy_diffusers":
in_prefix = key_set.legacy_diffusers_prefix
if not key.startswith(in_prefix):
continue
if key_set.filter_is_last is not None:
next_prefix = None
if source == "omi":
next_prefix = key_set.next_omi_prefix
elif source == "diffusers":
next_prefix = key_set.next_diffusers_prefix
elif source == "legacy_diffusers":
next_prefix = key_set.next_legacy_diffusers_prefix
is_last = not any(k.startswith(next_prefix) for k in state_dict)
if key_set.filter_is_last != is_last:
continue
name = key_set.get_key(in_prefix, key, target)
can_swap_chunks = target == "omi" or source == "omi"
if key_set.swap_chunks and name.endswith(".lora_up.weight") and can_swap_chunks:
chunk_0, chunk_1 = tensor.chunk(2, dim=0)
tensor = torch.cat([chunk_1, chunk_0], dim=0)
out_states[name] = tensor
break # only map the first matching key set
return out_states
def __detect_source(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> str:
omi_count = 0
diffusers_count = 0
legacy_diffusers_count = 0
for key in state_dict:
for key_set in key_sets:
if key.startswith(key_set.omi_prefix):
omi_count += 1
if key.startswith(key_set.diffusers_prefix):
diffusers_count += 1
if key.startswith(key_set.legacy_diffusers_prefix):
legacy_diffusers_count += 1
if omi_count > diffusers_count and omi_count > legacy_diffusers_count:
return "omi"
if diffusers_count > omi_count and diffusers_count > legacy_diffusers_count:
return "diffusers"
if legacy_diffusers_count > omi_count and legacy_diffusers_count > diffusers_count:
return "legacy_diffusers"
return ""
def convert_to_omi(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
source = __detect_source(state_dict, key_sets)
return __convert(state_dict, key_sets, source, "omi")
def convert_to_diffusers(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
source = __detect_source(state_dict, key_sets)
return __convert(state_dict, key_sets, source, "diffusers")
def convert_to_legacy_diffusers(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
source = __detect_source(state_dict, key_sets)
return __convert(state_dict, key_sets, source, "legacy_diffusers")

View File

@@ -1,125 +0,0 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
def __map_unet_resnet_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("emb_layers.1", "time_emb_proj", parent=key_prefix)]
keys += [LoraConversionKeySet("in_layers.2", "conv1", parent=key_prefix)]
keys += [LoraConversionKeySet("out_layers.3", "conv2", parent=key_prefix)]
keys += [LoraConversionKeySet("skip_connection", "conv_shortcut", parent=key_prefix)]
return keys
def __map_unet_attention_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("proj_in", "proj_in", parent=key_prefix)]
keys += [LoraConversionKeySet("proj_out", "proj_out", parent=key_prefix)]
for k in map_prefix_range("transformer_blocks", "transformer_blocks", parent=key_prefix):
keys += [LoraConversionKeySet("attn1.to_q", "attn1.to_q", parent=k)]
keys += [LoraConversionKeySet("attn1.to_k", "attn1.to_k", parent=k)]
keys += [LoraConversionKeySet("attn1.to_v", "attn1.to_v", parent=k)]
keys += [LoraConversionKeySet("attn1.to_out.0", "attn1.to_out.0", parent=k)]
keys += [LoraConversionKeySet("attn2.to_q", "attn2.to_q", parent=k)]
keys += [LoraConversionKeySet("attn2.to_k", "attn2.to_k", parent=k)]
keys += [LoraConversionKeySet("attn2.to_v", "attn2.to_v", parent=k)]
keys += [LoraConversionKeySet("attn2.to_out.0", "attn2.to_out.0", parent=k)]
keys += [LoraConversionKeySet("ff.net.0.proj", "ff.net.0.proj", parent=k)]
keys += [LoraConversionKeySet("ff.net.2", "ff.net.2", parent=k)]
return keys
def __map_unet_down_blocks(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += __map_unet_resnet_block(LoraConversionKeySet("1.0", "0.resnets.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("2.0", "0.resnets.1", parent=key_prefix))
keys += [LoraConversionKeySet("3.0.op", "0.downsamplers.0.conv", parent=key_prefix)]
keys += __map_unet_resnet_block(LoraConversionKeySet("4.0", "1.resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("4.1", "1.attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("5.0", "1.resnets.1", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("5.1", "1.attentions.1", parent=key_prefix))
keys += [LoraConversionKeySet("6.0.op", "1.downsamplers.0.conv", parent=key_prefix)]
keys += __map_unet_resnet_block(LoraConversionKeySet("7.0", "2.resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("7.1", "2.attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("8.0", "2.resnets.1", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("8.1", "2.attentions.1", parent=key_prefix))
return keys
def __map_unet_mid_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += __map_unet_resnet_block(LoraConversionKeySet("0", "resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("1", "attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("2", "resnets.1", parent=key_prefix))
return keys
def __map_unet_up_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += __map_unet_resnet_block(LoraConversionKeySet("0.0", "0.resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("0.1", "0.attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("1.0", "0.resnets.1", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("1.1", "0.attentions.1", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("2.0", "0.resnets.2", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("2.1", "0.attentions.2", parent=key_prefix))
keys += [LoraConversionKeySet("2.2.conv", "0.upsamplers.0.conv", parent=key_prefix)]
keys += __map_unet_resnet_block(LoraConversionKeySet("3.0", "1.resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("3.1", "1.attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("4.0", "1.resnets.1", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("4.1", "1.attentions.1", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("5.0", "1.resnets.2", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("5.1", "1.attentions.2", parent=key_prefix))
keys += [LoraConversionKeySet("5.2.conv", "1.upsamplers.0.conv", parent=key_prefix)]
keys += __map_unet_resnet_block(LoraConversionKeySet("6.0", "2.resnets.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("7.0", "2.resnets.1", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("8.0", "2.resnets.2", parent=key_prefix))
return keys
def __map_unet(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("input_blocks.0.0", "conv_in", parent=key_prefix)]
keys += [LoraConversionKeySet("time_embed.0", "time_embedding.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("time_embed.2", "time_embedding.linear_2", parent=key_prefix)]
keys += [LoraConversionKeySet("label_emb.0.0", "add_embedding.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("label_emb.0.2", "add_embedding.linear_2", parent=key_prefix)]
keys += __map_unet_down_blocks(LoraConversionKeySet("input_blocks", "down_blocks", parent=key_prefix))
keys += __map_unet_mid_block(LoraConversionKeySet("middle_block", "mid_block", parent=key_prefix))
keys += __map_unet_up_block(LoraConversionKeySet("output_blocks", "up_blocks", parent=key_prefix))
keys += [LoraConversionKeySet("out.0", "conv_norm_out", parent=key_prefix)]
keys += [LoraConversionKeySet("out.2", "conv_out", parent=key_prefix)]
return keys
def convert_sdxl_lora_key_sets() -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
keys += __map_unet(LoraConversionKeySet("unet", "lora_unet"))
keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
keys += map_clip(LoraConversionKeySet("clip_g", "lora_te2"))
return keys

View File

@@ -1,19 +0,0 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
def map_t5(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
for k in map_prefix_range("encoder.block", "encoder.block", parent=key_prefix):
keys += [LoraConversionKeySet("layer.0.SelfAttention.k", "layer.0.SelfAttention.k", parent=k)]
keys += [LoraConversionKeySet("layer.0.SelfAttention.o", "layer.0.SelfAttention.o", parent=k)]
keys += [LoraConversionKeySet("layer.0.SelfAttention.q", "layer.0.SelfAttention.q", parent=k)]
keys += [LoraConversionKeySet("layer.0.SelfAttention.v", "layer.0.SelfAttention.v", parent=k)]
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wi_0", "layer.1.DenseReluDense.wi_0", parent=k)]
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wi_1", "layer.1.DenseReluDense.wi_1", parent=k)]
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wo", "layer.1.DenseReluDense.wo", parent=k)]
return keys

View File

@@ -1,31 +0,0 @@
stable_diffusion_1_lora = "stable-diffusion-v1/lora"
stable_diffusion_1_inpainting_lora = "stable-diffusion-v1-inpainting/lora"
stable_diffusion_2_512_lora = "stable-diffusion-v2-512/lora"
stable_diffusion_2_768_v_lora = "stable-diffusion-v2-768-v/lora"
stable_diffusion_2_depth_lora = "stable-diffusion-v2-depth/lora"
stable_diffusion_2_inpainting_lora = "stable-diffusion-v2-inpainting/lora"
stable_diffusion_3_medium_lora = "stable-diffusion-v3-medium/lora"
stable_diffusion_35_medium_lora = "stable-diffusion-v3.5-medium/lora"
stable_diffusion_35_large_lora = "stable-diffusion-v3.5-large/lora"
stable_diffusion_xl_1_lora = "stable-diffusion-xl-v1-base/lora"
stable_diffusion_xl_1_inpainting_lora = "stable-diffusion-xl-v1-base-inpainting/lora"
wuerstchen_2_lora = "wuerstchen-v2-prior/lora"
stable_cascade_1_stage_a_lora = "stable-cascade-v1-stage-a/lora"
stable_cascade_1_stage_b_lora = "stable-cascade-v1-stage-b/lora"
stable_cascade_1_stage_c_lora = "stable-cascade-v1-stage-c/lora"
pixart_alpha_lora = "pixart-alpha/lora"
pixart_sigma_lora = "pixart-sigma/lora"
flux_dev_1_lora = "Flux.1-dev/lora"
flux_fill_dev_1_lora = "Flux.1-fill-dev/lora"
sana_lora = "sana/lora"
hunyuan_video_lora = "hunyuan-video/lora"
hi_dream_i1_lora = "hidream-i1/lora"

View File

@@ -23,7 +23,7 @@ class StarterModel(StarterModelWithoutDependencies):
dependencies: Optional[list[StarterModelWithoutDependencies]] = None
class StarterModelBundle(BaseModel):
class StarterModelBundles(BaseModel):
name: str
models: list[StarterModel]
@@ -109,7 +109,7 @@ flux_vae = StarterModel(
# region: Main
flux_schnell_quantized = StarterModel(
name="FLUX.1 schnell (quantized)",
name="FLUX Schnell (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
@@ -117,7 +117,7 @@ flux_schnell_quantized = StarterModel(
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_dev_quantized = StarterModel(
name="FLUX.1 dev (quantized)",
name="FLUX Dev (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
@@ -125,7 +125,7 @@ flux_dev_quantized = StarterModel(
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_schnell = StarterModel(
name="FLUX.1 schnell",
name="FLUX Schnell",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
@@ -133,29 +133,13 @@ flux_schnell = StarterModel(
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_dev = StarterModel(
name="FLUX.1 dev",
name="FLUX Dev",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext = StarterModel(
name="FLUX.1 Kontext dev",
base=BaseModelType.Flux,
source="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/resolve/main/flux1-kontext-dev.safetensors",
description="FLUX.1 Kontext dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext_quantized = StarterModel(
name="FLUX.1 Kontext dev (Quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
@@ -672,7 +656,6 @@ flux_fill = StarterModel(
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
flux_kontext_quantized,
flux_schnell_quantized,
flux_dev_quantized,
flux_schnell,
@@ -793,13 +776,12 @@ flux_bundle: list[StarterModel] = [
flux_depth_control_lora,
flux_redux,
flux_fill,
flux_kontext_quantized,
]
STARTER_BUNDLES: dict[str, StarterModelBundle] = {
BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle),
BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", models=sdxl_bundle),
BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle),
STARTER_BUNDLES: dict[str, list[StarterModel]] = {
BaseModelType.StableDiffusion1: sd1_bundle,
BaseModelType.StableDiffusionXL: sdxl_bundle,
BaseModelType.Flux: flux_bundle,
}
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"

View File

@@ -29,7 +29,6 @@ class BaseModelType(str, Enum):
Imagen3 = "imagen3"
Imagen4 = "imagen4"
ChatGPT4o = "chatgpt-4o"
FluxKontext = "flux-kontext"
class ModelType(str, Enum):

View File

@@ -17,15 +17,6 @@ module.exports = {
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
// Restrict setActiveTab calls to only use-navigation-api.tsx
'no-restricted-syntax': [
'error',
{
selector: 'CallExpression[callee.name="setActiveTab"]',
message:
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
},
],
// TODO: ENABLE THIS RULE BEFORE v6.0.0
'react/display-name': 'off',
'no-restricted-properties': [
@@ -42,38 +33,8 @@ module.exports = {
'The Clipboard API is not available by default in Firefox. Use the `useClipboard` hook instead, which wraps clipboard access to prevent errors.',
},
],
'no-restricted-imports': [
'error',
{
paths: [
{
name: 'lodash-es',
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
{
name: 'lodash-es',
message: 'Please use es-toolkit instead.',
},
{
name: 'es-toolkit',
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
],
},
],
},
overrides: [
/**
* Allow setActiveTab calls only in use-navigation-api.tsx
*/
{
files: ['**/use-navigation-api.tsx'],
rules: {
'no-restricted-syntax': 'off',
},
},
/**
* Overrides for stories
*/

View File

@@ -3,6 +3,8 @@ import type { KnipConfig } from 'knip';
const config: KnipConfig = {
project: ['src/**/*.{ts,tsx}!'],
ignore: [
// TODO(psyche): temporarily ignored all files for test build purposes
'src/**',
// This file is only used during debugging
'src/app/store/middleware/debugLoggerMiddleware.ts',
// Autogenerated types - shouldn't ever touch these
@@ -12,8 +14,10 @@ const config: KnipConfig = {
'src/features/parameters/types/parameterSchemas.ts',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
// Will be using this
'src/common/hooks/useAsyncState.ts',
// TODO(psyche): restore HRF functionality?
'src/features/hrf/**',
// This feature is (temprarily?) disabled
'src/features/controlLayers/components/InpaintMask/InpaintMaskAddButtons.tsx',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@@ -38,60 +38,71 @@
"test:ui": "vitest --coverage --ui",
"test:no-watch": "vitest --no-watch"
},
"madge": {
"excludeRegExp": [
"^index.ts$"
],
"detectiveOptions": {
"ts": {
"skipTypeImports": true
},
"tsx": {
"skipTypeImports": true
}
}
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.7.4",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.1",
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.1.0",
"@dagrejs/dagre": "^1.1.5",
"@atlaskit/pragmatic-drag-and-drop": "^1.5.3",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.0",
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
"@dagrejs/dagre": "^1.1.4",
"@dagrejs/graphlib": "^2.2.4",
"@fontsource-variable/inter": "^5.2.6",
"@fontsource-variable/inter": "^5.2.5",
"@invoke-ai/ui-library": "^0.0.46",
"@nanostores/react": "^1.0.0",
"@observ33r/object-equals": "^1.1.4",
"@reduxjs/toolkit": "2.8.2",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.7.1",
"ag-psd": "^28.2.1",
"@xyflow/react": "^12.6.0",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.1.1",
"compare-versions": "^6.1.1",
"dockview": "^4.4.0",
"es-toolkit": "^1.39.5",
"dockview": "^4.3.1",
"filesize": "^10.1.6",
"fracturedjsonjs": "^4.1.0",
"framer-motion": "^11.10.0",
"i18next": "^25.2.1",
"i18next": "^25.0.1",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "^6.2.2",
"idb-keyval": "^6.2.1",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.20",
"linkify-react": "^4.3.1",
"linkifyjs": "^4.3.1",
"linkify-react": "^4.2.0",
"linkifyjs": "^4.2.0",
"lodash-es": "^4.17.21",
"lru-cache": "^11.1.0",
"mtwist": "^1.0.2",
"nanoid": "^5.1.5",
"nanostores": "^1.0.1",
"new-github-issue-url": "^1.1.0",
"overlayscrollbars": "^2.11.4",
"overlayscrollbars": "^2.11.1",
"overlayscrollbars-react": "^0.5.6",
"perfect-freehand": "^1.2.2",
"query-string": "^9.2.1",
"query-string": "^9.1.1",
"raf-throttle": "^2.0.6",
"react": "^18.3.1",
"react-colorful": "^5.6.1",
"react-dom": "^18.3.1",
"react-dropzone": "^14.3.8",
"react-error-boundary": "^5.0.0",
"react-hook-form": "^7.58.1",
"react-hook-form": "^7.56.1",
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^15.5.3",
"react-i18next": "^15.5.1",
"react-icons": "^5.5.0",
"react-redux": "9.2.0",
"react-resizable-panels": "^3.0.3",
"react-resizable-panels": "^2.1.8",
"react-textarea-autosize": "^8.5.9",
"react-use": "^17.6.0",
"react-virtuoso": "^4.13.0",
"react-virtuoso": "^4.12.6",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.2.0",
"redux-undo": "^1.1.0",
@@ -99,12 +110,12 @@
"roarr": "^7.21.1",
"serialize-error": "^12.0.0",
"socket.io-client": "^4.8.1",
"stable-hash": "^0.0.6",
"use-debounce": "^10.0.5",
"stable-hash": "^0.0.5",
"use-debounce": "^10.0.4",
"use-device-pixel-ratio": "^1.1.2",
"uuid": "^11.1.0",
"zod": "^3.25.67",
"zod-validation-error": "^3.5.2"
"zod": "^3.24.3",
"zod-validation-error": "^3.4.0"
},
"peerDependencies": {
"react": "^18.2.0",
@@ -121,6 +132,7 @@
"@storybook/react": "^8.6.12",
"@storybook/react-vite": "^8.6.12",
"@storybook/theming": "^8.6.12",
"@types/lodash-es": "^4.17.12",
"@types/node": "^22.15.1",
"@types/react": "^18.3.11",
"@types/react-dom": "^18.3.0",
@@ -134,7 +146,7 @@
"eslint": "^8.57.1",
"eslint-plugin-i18next": "^6.1.1",
"eslint-plugin-path": "^1.3.0",
"knip": "^5.61.3",
"knip": "^5.50.5",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",
@@ -143,7 +155,7 @@
"tsafe": "^1.8.5",
"type-fest": "^4.40.0",
"typescript": "^5.8.3",
"vite": "^7.0.2",
"vite": "^6.3.3",
"vite-plugin-css-injected-by-js": "^3.5.2",
"vite-plugin-dts": "^4.5.3",
"vite-plugin-eslint": "^1.8.1",
@@ -151,7 +163,7 @@
"vitest": "^3.1.2"
},
"engines": {
"pnpm": "10"
"pnpm": "8"
},
"packageManager": "pnpm@10.12.4"
"packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +0,0 @@
onlyBuiltDependencies:
- '@swc/core'
- esbuild

View File

@@ -225,16 +225,7 @@
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noMatchingTriggers": "No matching triggers",
"generateFromImage": "Generate prompt from image",
"expandCurrentPrompt": "Expand Current Prompt",
"uploadImageForPromptGeneration": "Upload Image for Prompt Generation",
"expandingPrompt": "Expanding prompt...",
"resultTitle": "Prompt Expansion Complete",
"resultSubtitle": "Choose how to handle the expanded prompt:",
"replace": "Replace",
"insert": "Insert",
"discard": "Discard"
"noMatchingTriggers": "No matching triggers"
},
"queue": {
"queue": "Queue",
@@ -344,14 +335,14 @@
"images": "Images",
"assets": "Assets",
"alwaysShowImageSizeBadge": "Always Show Image Size Badge",
"assetsTab": "Files you've uploaded for use in your projects.",
"assetsTab": "Files youve uploaded for use in your projects.",
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
"autoSwitchNewImages": "Auto-Switch to New Images",
"boardsSettings": "Boards Settings",
"copy": "Copy",
"currentlyInUse": "This image is currently in use in the following features:",
"drop": "Drop",
"dropOrUpload": "Drop or Upload",
"dropOrUpload": "$t(gallery.drop) or Upload",
"dropToUpload": "$t(gallery.drop) to Upload",
"deleteImage_one": "Delete Image",
"deleteImage_other": "Delete {{count}} Images",
@@ -366,7 +357,7 @@
"gallerySettings": "Gallery Settings",
"go": "Go",
"image": "image",
"imagesTab": "Images you've created and saved within Invoke.",
"imagesTab": "Images youve created and saved within Invoke.",
"imagesSettings": "Gallery Images Settings",
"jump": "Jump",
"loading": "Loading",
@@ -405,8 +396,7 @@
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit.",
"openViewer": "Open Viewer",
"closeViewer": "Close Viewer",
"move": "Move",
"useForPromptGeneration": "Use for Prompt Generation"
"move": "Move"
},
"hotkeys": {
"hotkeys": "Hotkeys",
@@ -589,16 +579,6 @@
"cancelTransform": {
"title": "Cancel Transform",
"desc": "Cancel the pending transform."
},
"settings": {
"behavior": "Behavior",
"display": "Display",
"grid": "Grid",
"debug": "Debug"
},
"toggleNonRasterLayers": {
"title": "Toggle Non-Raster Layers",
"desc": "Show or hide all non-raster layer categories (Control Layers, Inpaint Masks, Regional Guidance)."
}
},
"workflows": {
@@ -762,7 +742,7 @@
"vae": "VAE",
"width": "Width",
"workflow": "Workflow",
"canvasV2Metadata": "Canvas Layers"
"canvasV2Metadata": "Canvas"
},
"modelManager": {
"active": "active",
@@ -783,7 +763,7 @@
"convertToDiffusers": "Convert To Diffusers",
"convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in the InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
@@ -826,11 +806,7 @@
"urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.",
"urlUnauthorizedErrorMessage2": "Learn how here.",
"imageEncoderModelId": "Image Encoder Model ID",
"installedModelsCount": "{{installed}} of {{total}} models installed.",
"includesNModels": "Includes {{n}} models and their dependencies.",
"allNModelsInstalled": "All {{count}} models installed",
"nToInstall": "{{count}} to install",
"nAlreadyInstalled": "{{count}} already installed",
"includesNModels": "Includes {{n}} models and their dependencies",
"installQueue": "Install Queue",
"inplaceInstall": "In-place install",
"inplaceInstallDesc": "Install models without copying the files. When using the model, it will be loaded from its this location. If disabled, the model file(s) will be copied into the Invoke-managed models directory during installation.",
@@ -893,25 +869,6 @@
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"bundleAlreadyInstalled": "Bundle already installed",
"bundleAlreadyInstalledDesc": "All models in the {{bundleName}} bundle are already installed.",
"launchpadTab": "Launchpad",
"launchpad": {
"welcome": "Welcome to Model Management",
"description": "Invoke requires models to be installed to utilize most features of the platform. Choose from manual installation options or explore curated starter models.",
"manualInstall": "Manual Installation",
"urlDescription": "Install models from a URL or local file path. Perfect for specific models you want to add.",
"huggingFaceDescription": "Browse and install models directly from HuggingFace repositories.",
"scanFolderDescription": "Scan a local folder to automatically detect and install models.",
"recommendedModels": "Recommended Models",
"exploreStarter": "Or browse all available starter models",
"quickStart": "Quick Start Bundles",
"bundleDescription": "Each bundle includes essential models for each model family and curated base models to get started.",
"browseAll": "Or browse all available models:",
"stableDiffusion15": "Stable Diffusion 1.5",
"sdxl": "SDXL",
"fluxDev": "FLUX.1 dev"
},
"controlLora": "Control LoRA",
"llavaOnevision": "LLaVA OneVision",
"syncModels": "Sync Models",
@@ -948,8 +905,7 @@
"selectModel": "Select a Model",
"noLoRAsInstalled": "No LoRAs installed",
"noRefinerModelsInstalled": "No SDXL Refiner models installed",
"defaultVAE": "Default VAE",
"noCompatibleLoRAs": "No Compatible LoRAs"
"defaultVAE": "Default VAE"
},
"nodes": {
"arithmeticSequence": "Arithmetic Sequence",
@@ -1191,7 +1147,6 @@
"modelIncompatibleScaledBboxWidth": "Scaled bbox width is {{width}} but {{model}} requires multiple of {{multiple}}",
"modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}",
"fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time",
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with Flux Kontext",
"canvasIsFiltering": "Canvas is busy (filtering)",
"canvasIsTransforming": "Canvas is busy (transforming)",
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
@@ -1199,9 +1154,7 @@
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected",
"promptExpansionPending": "Prompt expansion in progress",
"promptExpansionResultPending": "Please accept or discard your prompt expansion result"
"systemDisconnected": "System disconnected"
},
"maskBlur": "Mask Blur",
"negativePromptPlaceholder": "Negative Prompt",
@@ -1359,21 +1312,6 @@
"problemCopyingLayer": "Unable to Copy Layer",
"problemSavingLayer": "Unable to Save Layer",
"problemDownloadingImage": "Unable to Download Image",
"noRasterLayers": "No Raster Layers Found",
"noRasterLayersDesc": "Create at least one raster layer to export to PSD",
"noActiveRasterLayers": "No Active Raster Layers",
"noActiveRasterLayersDesc": "Enable at least one raster layer to export to PSD",
"noVisibleRasterLayers": "No Visible Raster Layers",
"noVisibleRasterLayersDesc": "Enable at least one raster layer to export to PSD",
"invalidCanvasDimensions": "Invalid Canvas Dimensions",
"canvasTooLarge": "Canvas Too Large",
"canvasTooLargeDesc": "Canvas dimensions exceed the maximum allowed size for PSD export. Reduce the total width and height of the canvas of the canvas and try again.",
"failedToProcessLayers": "Failed to Process Layers",
"psdExportSuccess": "PSD Export Complete",
"psdExportSuccessDesc": "Successfully exported {{count}} layers to PSD file",
"problemExportingPSD": "Problem Exporting PSD",
"canvasManagerNotAvailable": "Canvas Manager Not Available",
"noValidLayerAdapters": "No Valid Layer Adapters Found",
"pasteSuccess": "Pasted to {{destination}}",
"pasteFailed": "Paste Failed",
"prunedQueue": "Pruned Queue",
@@ -1399,15 +1337,9 @@
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext does not support generation from images placed on the canvas. Re-try using the Reference Image section and disable any Raster Layers.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished",
"sentToCanvas": "Sent to Canvas",
"sentToUpscale": "Sent to Upscale",
"promptGenerationStarted": "Prompt generation started",
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again."
"workflowUnpublished": "Workflow Unpublished"
},
"popovers": {
"clipSkip": {
@@ -1930,7 +1862,6 @@
"saveCanvasToGallery": "Save Canvas to Gallery",
"saveBboxToGallery": "Save Bbox to Gallery",
"saveLayerToAssets": "Save Layer to Assets",
"exportCanvasToPSD": "Export Canvas to PSD",
"cropLayerToBbox": "Crop Layer to Bbox",
"savedToGalleryOk": "Saved to Gallery",
"savedToGalleryError": "Error saving to gallery",
@@ -1956,13 +1887,11 @@
"mergingLayers": "Merging layers",
"clearHistory": "Clear History",
"bboxOverlay": "Show Bbox Overlay",
"ruleOfThirds": "Show Rule of Thirds",
"newSession": "New Session",
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"saveAllImagesToGallery": "Save All Images to Gallery",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -2063,8 +1992,6 @@
"disableTransparencyEffect": "Disable Transparency Effect",
"hidingType": "Hiding {{type}}",
"showingType": "Showing {{type}}",
"showNonRasterLayers": "Show Non-Raster Layers (Shift+H)",
"hideNonRasterLayers": "Hide Non-Raster Layers (Shift+H)",
"dynamicGrid": "Dynamic Grid",
"logDebugInfo": "Log Debug Info",
"locked": "Locked",
@@ -2331,9 +2258,6 @@
"label": "Preserve Masked Region",
"alert": "Preserving Masked Region"
},
"saveAllImagesToGallery": {
"alert": "Saving All Images to Gallery"
},
"isolatedStagingPreview": "Isolated Staging Preview",
"isolatedPreview": "Isolated Preview",
"isolatedLayerPreview": "Isolated Layer Preview",
@@ -2362,7 +2286,6 @@
"newGlobalReferenceImage": "New Global Reference Image",
"newRegionalReferenceImage": "New Regional Reference Image",
"newControlLayer": "New Control Layer",
"newResizedControlLayer": "New Resized Control Layer",
"newRasterLayer": "New Raster Layer",
"newInpaintMask": "New Inpaint Mask",
"newRegionalGuidance": "New Regional Guidance",
@@ -2380,11 +2303,6 @@
"saveToGallery": "Save To Gallery",
"showResultsOn": "Showing Results",
"showResultsOff": "Hiding Results"
},
"autoSwitch": {
"off": "Off",
"switchOnStart": "On Start",
"switchOnFinish": "On Finish"
}
},
"upscaling": {
@@ -2451,8 +2369,7 @@
"uploadImage": "Upload Image",
"useForTemplate": "Use For Prompt Template",
"viewList": "View Template List",
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box.",
"togglePromptPreviews": "Toggle Prompt Previews"
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box."
},
"upsell": {
"inviteTeammates": "Invite Teammates",
@@ -2472,55 +2389,6 @@
"upscaling": "Upscaling",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
"gallery": "Gallery"
},
"launchpad": {
"workflowsTitle": "Go deep with Workflows.",
"upscalingTitle": "Upscale and add detail.",
"canvasTitle": "Edit and refine on Canvas.",
"generateTitle": "Generate images from text prompts.",
"modelGuideText": "Want to learn what prompts work best for each model?",
"modelGuideLink": "Check out our Model Guide.",
"workflows": {
"description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.",
"learnMoreLink": "Learn more about creating workflows",
"browseTemplates": {
"title": "Browse Workflow Templates",
"description": "Choose from pre-built workflows for common tasks"
},
"createNew": {
"title": "Create a new Workflow",
"description": "Start a new workflow from scratch"
},
"loadFromFile": {
"title": "Load workflow from file",
"description": "Upload a workflow to start with an existing setup"
}
},
"upscaling": {
"uploadImage": {
"title": "Upload Image to Upscale",
"description": "Click or drag an image to upscale (JPG, PNG, WebP up to 100MB)"
},
"replaceImage": {
"title": "Replace Current Image",
"description": "Click or drag a new image to replace the current one"
},
"imageReady": {
"title": "Image Ready",
"description": "Press Invoke to begin upscaling"
},
"readyToUpscale": {
"title": "Ready to upscale!",
"description": "Configure your settings below, then click the Invoke button to begin upscaling your image."
},
"upscaleModel": "Upscale Model",
"model": "Model",
"scale": "Scale",
"helpText": {
"promptAdvice": "When upscaling, use a prompt that describes the medium and style. Avoid describing specific content details in the image.",
"styleAdvice": "Upscaling works best with the general style of your image."
}
}
}
},
"system": {
@@ -2560,9 +2428,8 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Generate images faster with new Launchpads and a simplified Generate tab.",
"Edit with prompts using Flux Kontext Dev.",
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
"Inpainting: Per-mask noise levels and denoise limits.",
"Canvas: Smarter aspect ratios for SDXL and improved scroll-to-zoom."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
@@ -2571,16 +2438,62 @@
"supportVideos": {
"supportVideos": "Support Videos",
"gettingStarted": "Getting Started",
"controlCanvas": "Control Canvas",
"watch": "Watch",
"studioSessionsDesc": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"studioSessionsDesc1": "Check out the <StudioSessionsPlaylistLink /> for Invoke deep dives.",
"studioSessionsDesc2": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"videos": {
"gettingStarted": {
"title": "Getting Started with Invoke",
"description": "Complete video series covering everything you need to know to get started with Invoke, from creating your first image to advanced techniques."
"creatingYourFirstImage": {
"title": "Creating Your First Image",
"description": "Introduction to creating an image from scratch using Invoke's tools."
},
"studioSessions": {
"title": "Studio Sessions",
"description": "Deep dive sessions exploring advanced Invoke features, creative workflows, and community discussions."
"usingControlLayersAndReferenceGuides": {
"title": "Using Control Layers and Reference Guides",
"description": "Learn how to guide your image creation with control layers and reference images."
},
"understandingImageToImageAndDenoising": {
"title": "Understanding Image-to-Image and Denoising",
"description": "Overview of image-to-image transformations and denoising in Invoke."
},
"exploringAIModelsAndConceptAdapters": {
"title": "Exploring AI Models and Concept Adapters",
"description": "Dive into AI models and how to use concept adapters for creative control."
},
"creatingAndComposingOnInvokesControlCanvas": {
"title": "Creating and Composing on Invoke's Control Canvas",
"description": "Learn to compose images using Invoke's control canvas."
},
"upscaling": {
"title": "Upscaling",
"description": "How to upscale images with Invoke's tools to enhance resolution."
},
"howDoIGenerateAndSaveToTheGallery": {
"title": "How Do I Generate and Save to the Gallery?",
"description": "Steps to generate and save images to the gallery."
},
"howDoIEditOnTheCanvas": {
"title": "How Do I Edit on the Canvas?",
"description": "Guide to editing images directly on the canvas."
},
"howDoIDoImageToImageTransformation": {
"title": "How Do I Do Image-to-Image Transformation?",
"description": "Tutorial on performing image-to-image transformations in Invoke."
},
"howDoIUseControlNetsAndControlLayers": {
"title": "How Do I Use Control Nets and Control Layers?",
"description": "Learn to apply control layers and controlnets to your images."
},
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
"title": "How Do I Use Global IP Adapters and Reference Images?",
"description": "Introduction to adding reference images and global IP adapters."
},
"howDoIUseInpaintMasks": {
"title": "How Do I Use Inpaint Masks?",
"description": "How to apply inpaint masks for image correction and variation."
},
"howDoIOutpaint": {
"title": "How Do I Outpaint?",
"description": "Guide to outpainting beyond the original image borders."
}
}
}

View File

@@ -2,7 +2,8 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $globalIsLoading } from 'app/store/nanostores/globalIsLoading';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
@@ -11,7 +12,6 @@ import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import ThemeLocaleProvider from './ThemeLocaleProvider';
const DEFAULT_CONFIG = {};
interface Props {
@@ -20,7 +20,7 @@ interface Props {
}
const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
const didStudioInit = useStore($didStudioInit);
const globalIsLoading = useStore($globalIsLoading);
const clearStorage = useClearStorage();
const handleReset = useCallback(() => {
@@ -31,14 +31,12 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<ThemeLocaleProvider>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ThemeLocaleProvider>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{globalIsLoading && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ErrorBoundary>
);
};

View File

@@ -1,5 +1,4 @@
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
import { setupListeners } from '@reduxjs/toolkit/query';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
@@ -11,15 +10,14 @@ import type { PartialAppConfig } from 'app/types/invokeai';
import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useCloseChakraTooltipsOnDragFix } from 'common/hooks/useCloseChakraTooltipsOnDragFix';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { useDndMonitor } from 'features/dnd/useDndMonitor';
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
import { useReadinessWatcher } from 'features/queue/store/readiness';
import { configChanged } from 'features/system/store/configSlice';
import { selectLanguage } from 'features/system/store/systemSelectors';
import { useNavigationApi } from 'features/ui/layouts/use-navigation-api';
import i18n from 'i18n';
import { size } from 'lodash-es';
import { memo, useEffect } from 'react';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import { useGetQueueCountsByDestinationQuery } from 'services/api/endpoints/queue';
@@ -45,8 +43,6 @@ export const GlobalHookIsolator = memo(
useGetOpenAPISchemaQuery();
useSyncLoggingConfig();
useCloseChakraTooltipsOnDragFix();
useNavigationApi();
useDndMonitor();
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
// and/or in progress canvas sessions.
@@ -57,18 +53,16 @@ export const GlobalHookIsolator = memo(
}, [language]);
useEffect(() => {
logger.info({ config }, 'Received config');
dispatch(configChanged(config));
if (size(config)) {
logger.info({ config }, 'Received config');
dispatch(configChanged(config));
}
}, [dispatch, config, logger]);
useEffect(() => {
dispatch(appStarted());
}, [dispatch]);
useEffect(() => {
return setupListeners(dispatch);
}, [dispatch]);
useStudioInitAction(studioInitAction);
useStarterModelsToast();
useSyncQueueStatus();

View File

@@ -1,14 +1,11 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
@@ -30,64 +27,59 @@ GlobalImageHotkeys.displayName = 'GlobalImageHotkeys';
const GlobalImageHotkeysInternal = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const isGalleryFocused = useIsRegionFocused('gallery');
const isViewerFocused = useIsRegionFocused('viewer');
const isFocusOK = isGalleryFocused || isViewerFocused;
const recallAll = useRecallAll(imageDTO);
const recallRemix = useRecallRemix(imageDTO);
const recallPrompts = useRecallPrompts(imageDTO);
const recallSeed = useRecallSeed(imageDTO);
const recallDimensions = useRecallDimensions(imageDTO);
const loadWorkflow = useLoadWorkflow(imageDTO);
const imageActions = useImageActions(imageDTO);
const isStaging = useAppSelector(selectIsStaging);
const isUpscalingEnabled = useFeatureStatus('upscaling');
useRegisteredHotkeys({
id: 'loadWorkflow',
category: 'viewer',
callback: loadWorkflow.load,
options: { enabled: loadWorkflow.isEnabled && isFocusOK },
dependencies: [loadWorkflow, isFocusOK],
callback: imageActions.loadWorkflow,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.loadWorkflow, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'recallAll',
category: 'viewer',
callback: recallAll.recall,
options: { enabled: recallAll.isEnabled && isFocusOK },
dependencies: [recallAll, isFocusOK],
callback: imageActions.recallAll,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallAll, isStaging, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'recallSeed',
category: 'viewer',
callback: recallSeed.recall,
options: { enabled: recallSeed.isEnabled && isFocusOK },
dependencies: [recallSeed, isFocusOK],
callback: imageActions.recallSeed,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallSeed, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'recallPrompts',
category: 'viewer',
callback: recallPrompts.recall,
options: { enabled: recallPrompts.isEnabled && isFocusOK },
dependencies: [recallPrompts, isFocusOK],
callback: imageActions.recallPrompts,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallPrompts, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'remix',
category: 'viewer',
callback: recallRemix.recall,
options: { enabled: recallRemix.isEnabled && isFocusOK },
dependencies: [recallRemix, isFocusOK],
callback: imageActions.remix,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.remix, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'useSize',
category: 'viewer',
callback: recallDimensions.recall,
options: { enabled: recallDimensions.isEnabled && isFocusOK },
dependencies: [recallDimensions, isFocusOK],
callback: imageActions.recallSize,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallSize, isStaging, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'runPostprocessing',
category: 'viewer',
callback: imageActions.upscale,
options: { enabled: isUpscalingEnabled && isViewerFocused },
dependencies: [isUpscalingEnabled, imageDTO, isViewerFocused],
});
return null;
});

View File

@@ -42,6 +42,7 @@ import { $socketOptions } from 'services/events/stores';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
interface Props extends PropsWithChildren {
apiUrl?: string;
@@ -329,7 +330,9 @@ const InvokeAIUI = ({
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
<ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
</ThemeLocaleProvider>
</React.Suspense>
</Provider>
</React.StrictMode>

View File

@@ -8,7 +8,7 @@ import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { sentImageToCanvas } from 'features/gallery/store/actions';
import { MetadataUtils } from 'features/metadata/parsing';
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { $isWorkflowLibraryModalOpen } from 'features/nodes/store/workflowLibraryModal';
import {
@@ -19,9 +19,7 @@ import {
} from 'features/nodes/store/workflowLibrarySlice';
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { activeTabCanvasRightPanelChanged, setActiveTab } from 'features/ui/store/uiSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
@@ -92,7 +90,6 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
};
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
store.dispatch(canvasReset());
store.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
store.dispatch(sentImageToCanvas());
@@ -119,23 +116,23 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const metadata = getImageMetadataResult.value;
store.dispatch(canvasReset());
// This shows a toast
await MetadataUtils.recallAll(metadata, store);
await parseAndRecallAllMetadata(metadata, true);
},
[store, t]
);
const handleLoadWorkflow = useCallback(
(workflowId: string) => {
async (workflowId: string) => {
// This shows a toast
loadWorkflowWithDialog({
await loadWorkflowWithDialog({
type: 'library',
data: workflowId,
onSuccess: () => {
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
},
});
},
[loadWorkflowWithDialog]
[loadWorkflowWithDialog, store]
);
const handleSelectStylePreset = useCallback(
@@ -149,7 +146,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
return;
}
store.dispatch(activeStylePresetIdChanged(stylePresetId));
navigationApi.switchToTab('canvas');
store.dispatch(setActiveTab('canvas'));
toast({
title: t('toast.stylePresetLoaded'),
status: 'info',
@@ -159,34 +156,33 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
);
const handleGoToDestination = useCallback(
async (destination: StudioDestinationAction['data']['destination']) => {
(destination: StudioDestinationAction['data']['destination']) => {
switch (destination) {
case 'generation':
// Go to the generate tab, open the launchpad
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
// Go to the canvas tab, open the image viewer, and enable send-to-gallery mode
store.dispatch(paramsReset());
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
break;
case 'canvas':
// Go to the canvas tab, open the launchpad
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
// Go to the canvas tab, close the image viewer, and disable send-to-gallery mode
store.dispatch(canvasReset());
break;
case 'workflows':
// Go to the workflows tab
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
break;
case 'upscaling':
// Go to the upscaling tab
navigationApi.switchToTab('upscaling');
store.dispatch(setActiveTab('upscaling'));
break;
case 'viewAllWorkflows':
// Go to the workflows tab and open the workflow library modal
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
$isWorkflowLibraryModalOpen.set(true);
break;
case 'viewAllWorkflowsRecommended':
// Go to the workflows tab and open the workflow library modal with the recommended workflows view
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
$isWorkflowLibraryModalOpen.set(true);
store.dispatch(workflowLibraryViewChanged('defaults'));
store.dispatch(workflowLibraryTagsReset());
@@ -198,7 +194,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
break;
case 'viewAllStylePresets':
// Go to the canvas tab and open the style presets menu
navigationApi.switchToTab('canvas');
store.dispatch(setActiveTab('canvas'));
$isStylePresetsMenuOpen.set(true);
break;
}

View File

@@ -2,7 +2,7 @@ import { createLogWriter } from '@roarr/browser-log-writer';
import { atom } from 'nanostores';
import type { Logger, MessageSerializer } from 'roarr';
import { ROARR, Roarr } from 'roarr';
import { z } from 'zod/v4';
import { z } from 'zod';
const serializeMessage: MessageSerializer = (message) => {
return JSON.stringify(message);

View File

@@ -1,13 +1,13 @@
import { objectEquals } from '@observ33r/object-equals';
import { createDraftSafeSelectorCreator, createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es';
/**
* A memoized selector creator that uses LRU cache and @observ33r/object-equals's objectEquals for equality check.
* A memoized selector creator that uses LRU cache and lodash's isEqual for equality check.
*/
export const createMemoizedSelector = createSelectorCreator({
memoize: lruMemoize,
memoizeOptions: {
resultEqualityCheck: objectEquals,
resultEqualityCheck: isEqual,
},
argsMemoize: lruMemoize,
});

View File

@@ -8,13 +8,10 @@ import { diff } from 'jsondiffpatch';
* Super simple logger middleware. Useful for debugging when the redux devtools are awkward.
*/
export const getDebugLoggerMiddleware =
(options?: { filter?: (action: unknown) => boolean; withDiff?: boolean; withNextState?: boolean }): Middleware =>
(options?: { withDiff?: boolean; withNextState?: boolean }): Middleware =>
(api: MiddlewareAPI) =>
(next) =>
(action) => {
if (options?.filter?.(action)) {
return next(action);
}
const originalState = api.getState();
console.log('REDUX: dispatching', action);
const result = next(action);

View File

@@ -8,6 +8,9 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnsureImageIsSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/ensureImageIsSelectedListener';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
@@ -19,6 +22,7 @@ import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMi
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';
export const listenerMiddleware = createListenerMiddleware();
@@ -40,7 +44,12 @@ addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// Gallery
addGalleryImageClickedListener(startAppListening);
// User Invoked
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
@@ -71,3 +80,5 @@ addAppConfigReceivedListener(startAppListening);
addAdHocPostProcessingRequestedListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);
addEnsureImageIsSelectedListener(startAppListening);

View File

@@ -1,29 +1,15 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
export const appStarted = createAction('app/appStarted');
export const addAppStartedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: appStarted,
effect: async (action, { unsubscribe, cancelActiveListeners, take, getState, dispatch }) => {
effect: (action, { unsubscribe, cancelActiveListeners }) => {
// this should only run once
cancelActiveListeners();
unsubscribe();
// ensure an image is selected when we load the first board
const firstImageLoad = await take(imagesApi.endpoints.getImageNames.matchFulfilled);
if (firstImageLoad !== null) {
const [{ payload }] = firstImageLoad;
const selectedImage = selectLastSelectedImage(getState());
if (selectedImage) {
return;
}
dispatch(imageSelected(payload.image_names.at(0) ?? null));
}
},
});
};

View File

@@ -1,9 +1,9 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { truncate } from 'es-toolkit/compat';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { JsonObject } from 'type-fest';

View File

@@ -1,6 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -11,35 +11,36 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
// Cancel any in-progress instances of this listener, we don't want to select an image from a previous board
cancelActiveListeners();
if (boardIdSelected.match(action) && action.payload.selectedImageName) {
// This action already has a selected image name, we trust it is valid
return;
}
const state = getState();
const board_id = selectSelectedBoardId(state);
const queryArgs = { ...selectGetImageNamesQueryArgs(state), board_id };
const queryArgs = { ...selectListImagesBaseQueryArgs(state), offset: 0 };
// wait until the board has some images - maybe it already has some from a previous fetch
// must use getState() to ensure we do not have stale state
const isSuccess = await condition(
() => imagesApi.endpoints.getImageNames.select(queryArgs)(getState()).isSuccess,
() => imagesApi.endpoints.listImages.select(queryArgs)(getState()).isSuccess,
5000
);
if (!isSuccess) {
if (isSuccess) {
// the board was just changed - we can select the first image
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
const selectedImage = boardImagesData.items.find(
(item) => item.image_name === action.payload.selectedImageName
);
dispatch(imageSelected(selectedImage?.image_name ?? null));
} else if (boardImagesData) {
dispatch(imageSelected(boardImagesData.items[0]?.image_name ?? null));
} else {
// board has no images - deselect
dispatch(imageSelected(null));
}
} else {
// fallback - deselect
dispatch(imageSelected(null));
return;
}
// the board was just changed - we can select the first image
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(getState()).data?.image_names;
const imageToSelect = imageNames?.at(0) ?? null;
dispatch(imageSelected(imageToSelect));
},
});
};

View File

@@ -0,0 +1,151 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import {
canvasSessionIdCreated,
generateSessionIdCreated,
selectCanvasSessionId,
selectGenerateSessionId,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: enqueueRequestedCanvas,
effect: async (action, { getState, dispatch }) => {
log.debug('Enqueue requested');
const tab = selectActiveTab(getState());
let sessionId = null;
if (tab === 'generate') {
sessionId = selectGenerateSessionId(getState());
if (!sessionId) {
dispatch(generateSessionIdCreated());
sessionId = selectGenerateSessionId(getState());
}
} else if (tab === 'canvas') {
sessionId = selectCanvasSessionId(getState());
if (!sessionId) {
dispatch(canvasSessionIdCreated());
sessionId = selectCanvasSessionId(getState());
}
} else {
log.warn(`Enqueue requested in unsupported tab ${tab}`);
return;
}
const state = getState();
const destination = sessionId;
assert(destination !== null);
const { prepend } = action.payload;
const manager = $canvasManager.get();
// assert(manager, 'No canvas manager');
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, manager);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, manager);
case `sd-3`:
return await buildSD3Graph(state, manager);
case `flux`:
return await buildFLUXGraph(state, manager);
case 'cogview4':
return await buildCogView4Graph(state, manager);
case 'imagen3':
return await buildImagen3Graph(state, manager);
case 'imagen4':
return await buildImagen4Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, enqueueMutationFixedCacheKeyOptions)
);
try {
await req.unwrap();
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}
},
});
};

View File

@@ -0,0 +1,44 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
const log = logger('generation');
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: enqueueRequestedUpscaling,
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { prepend } = action.payload;
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
try {
await req.unwrap();
log.debug(parseify({ batchConfig }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}
},
});
};

View File

@@ -0,0 +1,16 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
export const addEnsureImageIsSelectedListener = (startAppListening: AppStartListening) => {
// When we list images, if no images is selected, select the first one.
startAppListening({
matcher: imagesApi.endpoints.listImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const selection = getState().gallery.selection;
if (selection.length === 0) {
dispatch(imageSelected(action.payload.items[0]?.image_name ?? null));
}
},
});
};

View File

@@ -0,0 +1,77 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { uniq } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
export const galleryImageClicked = createAction<{
imageName: string;
shiftKey: boolean;
ctrlKey: boolean;
metaKey: boolean;
altKey: boolean;
}>('gallery/imageClicked');
/**
* This listener handles the logic for selecting images in the gallery.
*
* Previously, this logic was in a `useCallback` with the whole gallery selection as a dependency. Every time
* the selection changed, the callback got recreated and all images rerendered. This could easily block for
* hundreds of ms, more for lower end devices.
*
* Moving this logic into a listener means we don't need to recalculate anything dynamically and the gallery
* is much more responsive.
*/
export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: galleryImageClicked,
effect: (action, { dispatch, getState }) => {
const { imageName, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImageNamesQueryArgs(state);
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data ?? [];
// If we don't have the image names cached, we can't perform selection operations
// This can happen if the user clicks on an image before the names are loaded
if (imageNames.length === 0) {
// For basic click without modifiers, we can still set selection
if (!shiftKey && !ctrlKey && !metaKey && !altKey) {
dispatch(selectionChanged([imageName]));
}
return;
}
const selection = state.gallery.selection;
if (altKey) {
if (state.gallery.imageToCompare === imageName) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(imageToCompareChanged(imageName));
}
} else if (shiftKey) {
const rangeEndImageName = imageName;
const lastSelectedImage = selection.at(-1);
const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage);
const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const imagesToSelect = imageNames.slice(start, end + 1);
dispatch(selectionChanged(uniq(selection.concat(imagesToSelect))));
}
} else if (ctrlKey || metaKey) {
if (selection.some((n) => n === imageName) && selection.length > 1) {
dispatch(selectionChanged(uniq(selection.filter((n) => n !== imageName))));
} else {
dispatch(selectionChanged(uniq(selection.concat(imageName))));
}
} else {
dispatch(selectionChanged([imageName]));
}
},
});
};

View File

@@ -1,9 +1,9 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { size } from 'es-toolkit/compat';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { JsonObject } from 'type-fest';

View File

@@ -2,12 +2,12 @@ import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { omit } from 'es-toolkit/compat';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';

View File

@@ -1,28 +1,14 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
import {
selectAllEntitiesOfType,
selectBboxModelBase,
selectCanvasSlice,
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectBboxModelBase } from 'features/controlLayers/store/selectors';
import { modelSelected } from 'features/parameters/store/actions';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { selectGlobalRefImageModels, selectRegionalRefImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig } from 'services/api/types';
import {
isChatGPT4oModelConfig,
isFluxKontextApiModelConfig,
isFluxKontextModelConfig,
isFluxReduxModelConfig,
} from 'services/api/types';
const log = logger('models');
@@ -39,8 +25,9 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
const newModel = result.data;
const newBase = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBase;
const newBaseModel = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBaseModel;
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
@@ -48,7 +35,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
// handle incompatible loras
state.loras.loras.forEach((lora) => {
if (lora.model.base !== newBase) {
if (lora.model.base !== newBaseModel) {
dispatch(loraDeleted({ id: lora.id }));
modelsCleared += 1;
}
@@ -56,82 +43,20 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
// handle incompatible vae
const { vae } = state.params;
if (vae && vae.base !== newBase) {
if (vae && vae.base !== newBaseModel) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
// to choose the best available model based on the new main model.
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
let newGlobalRefImageModel = null;
// Certain models require the ref image model to be the same as the main model - others just need a matching
// base. Helper to grab the first exact match or the first available model if no exact match is found.
const exactMatchOrFirst = <T extends AnyModelConfig>(candidates: T[]): T | null =>
candidates.find(({ key }) => key === newModel.key) ?? candidates[0] ?? null;
// The only way we can differentiate between FLUX and FLUX Kontext is to check for "kontext" in the name
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
const fluxKontextDevModels = allRefImageModels.filter(isFluxKontextModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextDevModels);
} else if (newModel.base === 'chatgpt-4o') {
const chatGPT4oModels = allRefImageModels.filter(isChatGPT4oModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(chatGPT4oModels);
} else if (newModel.base === 'flux-kontext') {
const fluxKontextApiModels = allRefImageModels.filter(isFluxKontextApiModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextApiModels);
} else if (newModel.base === 'flux') {
const fluxReduxModels = allRefImageModels.filter(isFluxReduxModelConfig);
newGlobalRefImageModel = fluxReduxModels[0] ?? null;
} else {
newGlobalRefImageModel = allRefImageModels[0] ?? null;
}
// All ref image entities are updated to use the same new model
const refImageEntities = selectReferenceImageEntities(state);
for (const entity of refImageEntities) {
const shouldUpdateModel =
(entity.config.model && entity.config.model.base !== newBase) ||
(!entity.config.model && newGlobalRefImageModel);
if (shouldUpdateModel) {
dispatch(
refImageModelChanged({
id: entity.id,
modelConfig: newGlobalRefImageModel,
})
);
modelsCleared += 1;
}
}
// For regional guidance, there is no smart logic - we just pick the first available model.
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;
// All regional guidance entities are updated to use the same new model.
const canvasState = selectCanvasSlice(state);
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
for (const entity of canvasRegionalGuidanceEntities) {
for (const refImage of entity.referenceImages) {
// Only change the model if the current one is not compatible with the new base model.
const shouldUpdateModel =
(refImage.config.model && refImage.config.model.base !== newBase) ||
(!refImage.config.model && newRegionalRefImageModel);
if (shouldUpdateModel) {
dispatch(
rgRefImageModelChanged({
entityIdentifier: getEntityIdentifier(entity),
referenceImageId: refImage.id,
modelConfig: newRegionalRefImageModel,
})
);
modelsCleared += 1;
}
}
}
// handle incompatible controlnets
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
// if (ca.model?.base !== newBaseModel) {
// modelsCleared += 1;
// if (ca.isEnabled) {
// dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
// }
// }
// });
if (modelsCleared > 0) {
toast({
@@ -146,16 +71,9 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
const modelBase = selectBboxModelBase(state);
if (modelBase !== state.params.model?.base) {
// Sync generate tab settings whenever the model base changes
dispatch(syncedToOptimalDimension());
if (!selectIsStaging(state)) {
// Canvas tab only syncs if not staging
dispatch(bboxSyncedToOptimalDimension());
}
if (!selectIsStaging(state) && modelBase !== state.params.model?.base) {
dispatch(bboxSyncedToOptimalDimension());
}
},
});

View File

@@ -1,9 +1,7 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
heightChanged,
setCfgRescaleMultiplier,
setCfgScale,
setGuidance,
@@ -11,7 +9,6 @@ import {
setSteps,
vaePrecisionChanged,
vaeSelected,
widthChanged,
} from 'features/controlLayers/store/paramsSlice';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
@@ -26,7 +23,6 @@ import {
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { t } from 'i18next';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
@@ -90,16 +86,10 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
}
if (!isNil(cfg_rescale_multiplier)) {
if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
} else {
// Set this to 0 if it doesn't have a default. This value is
// easy to miss in the UI when users are resetting defaults
// and leaving it non-zero could lead to detrimental
// effects.
dispatch(setCfgRescaleMultiplier(0));
}
if (steps) {
@@ -116,24 +106,15 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
const setSizeOptions = { updateAspectRatio: true, clamp: true };
const isStaging = selectIsStaging(getState());
const activeTab = selectActiveTab(getState());
if (activeTab === 'generate') {
if (!isStaging && width) {
if (isParameterWidth(width)) {
dispatch(widthChanged({ width, ...setSizeOptions }));
}
if (isParameterHeight(height)) {
dispatch(heightChanged({ height, ...setSizeOptions }));
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
}
}
if (activeTab === 'canvas') {
if (!isStaging) {
if (isParameterWidth(width)) {
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
}
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
}
if (!isStaging && height) {
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
}
}

View File

@@ -1,8 +1,8 @@
import { objectEquals } from '@observ33r/object-equals';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { isEqual } from 'lodash-es';
import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
@@ -64,7 +64,7 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
const nextQueueStatusData = await queueStatusRequest.unwrap();
// If the queue hasn't changed, we don't need to do anything.
if (objectEquals(prevQueueStatusData?.queue, nextQueueStatusData.queue)) {
if (isEqual(prevQueueStatusData?.queue, nextQueueStatusData.queue)) {
return;
}

View File

@@ -0,0 +1,13 @@
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import { atom, computed } from 'nanostores';
import { flushSync } from 'react-dom';
export const $isLayoutLoading = atom(false);
export const setIsLayoutLoading = (isLoading: boolean) => {
flushSync(() => {
$isLayoutLoading.set(isLoading);
});
};
export const $globalIsLoading = computed([$didStudioInit, $isLayoutLoading], (didStudioInit, isLayoutLoading) => {
return !didStudioInit || isLayoutLoading;
});

View File

@@ -1,3 +1,4 @@
import { useStore } from '@nanostores/react';
import type { AppStore } from 'app/store/store';
import { atom } from 'nanostores';
@@ -31,3 +32,11 @@ export const getStore = () => {
}
return store;
};
export const useAppStore = () => {
const store = useStore($store);
if (!store) {
throw new ReduxStoreNotInitialized();
}
return store;
};

View File

@@ -11,7 +11,5 @@ export const $false: ReadableAtom<boolean> = atom(false);
/**
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
* in a hook or component.
*
* @knipignore
*/
export const $true: ReadableAtom<boolean> = atom(true);

View File

@@ -4,7 +4,6 @@ import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { deepClone } from 'common/util/deepClone';
import { keys, mergeWith, omit, pick } from 'es-toolkit/compat';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
@@ -17,6 +16,7 @@ import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/p
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice';
@@ -28,6 +28,7 @@ import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import { keys, mergeWith, omit, pick } from 'lodash-es';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
@@ -56,6 +57,7 @@ const allReducers = {
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[queueSlice.name]: queueSlice.reducer,
[hrfSlice.name]: hrfSlice.reducer,
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
@@ -101,6 +103,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[uiPersistConfig.name]: uiPersistConfig,
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
[canvasPersistConfig.name]: canvasPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,

View File

@@ -1,8 +1,8 @@
import type { AppStore, AppThunkDispatch, RootState } from 'app/store/store';
import type { AppThunkDispatch, RootState } from 'app/store/store';
import type { TypedUseSelectorHook } from 'react-redux';
import { useDispatch, useSelector, useStore } from 'react-redux';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
export const useAppStore = () => useStore.withTypes<AppStore>()();
export const useAppStore = () => useStore<RootState>();

View File

@@ -1,6 +1,6 @@
import type { Selector } from '@reduxjs/toolkit';
import { useAppStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { useEffect, useState } from 'react';
/**

View File

@@ -14,7 +14,6 @@ export type AppFeature =
| 'githubLink'
| 'discordLink'
| 'bugLink'
| 'aboutModal'
| 'localization'
| 'consoleLogging'
| 'dynamicPrompting'
@@ -30,8 +29,7 @@ export type AppFeature =
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll'
| 'chatGPT4oHigh'
| 'modelRelationships';
| 'chatGPT4oHigh';
/**
* A disable-able Stable Diffusion feature
*/
@@ -78,7 +76,6 @@ export type AppConfig = {
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
allowPublishWorkflows: boolean;
allowPromptExpansion: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];

View File

@@ -0,0 +1,56 @@
import { Box, type BoxProps, type SystemStyleObject } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { type FocusRegionName, useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { selectSystemShouldEnableHighlightFocusedRegions } from 'features/system/store/systemSlice';
import { memo, useMemo, useRef } from 'react';
interface FocusRegionWrapperProps extends BoxProps {
region: FocusRegionName;
focusOnMount?: boolean;
}
const FOCUS_REGION_STYLES: SystemStyleObject = {
position: 'relative',
'&[data-highlighted="true"]::after': {
borderColor: 'blue.700',
},
'&::after': {
content: '""',
position: 'absolute',
inset: 0,
zIndex: 1,
borderRadius: 'base',
border: '2px solid',
borderColor: 'transparent',
pointerEvents: 'none',
transition: 'border-color 0.1s ease-in-out',
},
};
export const FocusRegionWrapper = memo(
({ region, focusOnMount = false, sx, children, ...boxProps }: FocusRegionWrapperProps) => {
const shouldHighlightFocusedRegions = useAppSelector(selectSystemShouldEnableHighlightFocusedRegions);
const ref = useRef<HTMLDivElement>(null);
const options = useMemo(() => ({ focusOnMount }), [focusOnMount]);
useFocusRegion(region, ref, options);
const isFocused = useIsRegionFocused(region);
const isHighlighted = isFocused && shouldHighlightFocusedRegions;
return (
<Box
ref={ref}
tabIndex={-1}
sx={useMemo(() => ({ ...FOCUS_REGION_STYLES, ...sx }), [sx])}
data-highlighted={isHighlighted}
{...boxProps}
>
{children}
</Box>
);
}
);
FocusRegionWrapper.displayName = 'FocusRegionWrapper';

View File

@@ -15,9 +15,9 @@ import {
} from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { merge, omit } from 'es-toolkit/compat';
import { selectSystemSlice, setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { merge, omit } from 'lodash-es';
import type { ReactElement } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';

View File

@@ -8,16 +8,21 @@ const Loading = () => {
return (
<Flex
position="absolute"
width="100dvw"
height="100dvh"
alignItems="center"
justifyContent="center"
bg="hsl(220 12% 10% / 1)" // base.900
inset={0}
bg="#151519"
top={0}
right={0}
bottom={0}
left={0}
zIndex={99999}
>
<Image src={InvokeLogoWhite} w="8rem" h="8rem" />
<Spinner
label="Loading"
color="hsl(220 12% 68% / 1)" // base.300
color="grey"
position="absolute"
size="sm"
width="24px !important"

View File

@@ -1,5 +1,5 @@
import { deepClone } from 'common/util/deepClone';
import { merge } from 'es-toolkit/compat';
import { merge } from 'lodash-es';
import { ClickScrollPlugin, OverlayScrollbars } from 'overlayscrollbars';
import type { UseOverlayScrollbarsParams } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';

View File

@@ -87,7 +87,7 @@ export const buildGroup = <T extends object>(group: Omit<Group<T>, typeof unique
[uniqueGroupKey]: true,
});
export const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
};
@@ -198,10 +198,6 @@ type PickerProps<T extends object> = {
* Whether the picker should be searchable. If true, renders a search input.
*/
searchable?: boolean;
/**
* Initial state for group toggles. If provided, groups will start with these states instead of all being disabled.
*/
initialGroupStates?: GroupStatusMap;
};
export type PickerContextState<T extends object> = {
@@ -314,9 +310,9 @@ const flattenOptions = <T extends object>(options: OptionOrGroup<T>[]): T[] => {
return flattened;
};
export type GroupStatusMap = Record<string, boolean>;
type GroupStatusMap = Record<string, boolean>;
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[], initialGroupStates?: GroupStatusMap) => {
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
const groupsWithOptions = useMemo(() => {
const ids: string[] = [];
for (const optionOrGroup of options) {
@@ -336,16 +332,14 @@ const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[], initi
const groupStatusMap = $groupStatusMap.get();
const newMap: GroupStatusMap = {};
for (const id of groupsWithOptions) {
if (initialGroupStates && initialGroupStates[id] !== undefined) {
newMap[id] = initialGroupStates[id];
if (newMap[id] === undefined) {
newMap[id] = false;
} else if (groupStatusMap[id] !== undefined) {
newMap[id] = groupStatusMap[id];
} else {
newMap[id] = false;
}
}
$groupStatusMap.set(newMap);
}, [groupsWithOptions, $groupStatusMap, initialGroupStates]);
}, [groupsWithOptions, $groupStatusMap]);
const toggleGroup = useCallback(
(idToToggle: string) => {
@@ -517,14 +511,10 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
OptionComponent = DefaultOptionComponent,
NextToSearchBar,
searchable,
initialGroupStates,
} = props;
const rootRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(
optionsOrGroups,
initialGroupStates
);
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(optionsOrGroups);
const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId));
const $compactView = useAtom(true);
const $optionsOrGroups = useAtom(optionsOrGroups);

View File

@@ -1,15 +1,20 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import {
useNewCanvasSession,
useNewGallerySession,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import { allEntitiesDeleted } from 'features/controlLayers/store/canvasSlice';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold } from 'react-icons/pi';
import { PiArrowsCounterClockwiseBold, PiFilePlusBold } from 'react-icons/pi';
export const SessionMenuItems = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { newGallerySessionWithDialog } = useNewGallerySession();
const { newCanvasSessionWithDialog } = useNewCanvasSession();
const resetCanvasLayers = useCallback(() => {
dispatch(allEntitiesDeleted());
}, [dispatch]);
@@ -18,6 +23,12 @@ export const SessionMenuItems = memo(() => {
}, [dispatch]);
return (
<>
<MenuItem icon={<PiFilePlusBold />} onClick={newGallerySessionWithDialog}>
{t('controlLayers.newGallerySession')}
</MenuItem>
<MenuItem icon={<PiFilePlusBold />} onClick={newCanvasSessionWithDialog}>
{t('controlLayers.newCanvasSession')}
</MenuItem>
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetCanvasLayers}>
{t('controlLayers.resetCanvasLayers')}
</MenuItem>

View File

@@ -6,7 +6,6 @@ import { atom, computed } from 'nanostores';
import type { RefObject } from 'react';
import { useEffect } from 'react';
import { objectKeys } from 'tsafe';
import z from 'zod/v4';
/**
* We need to manage focus regions to conditionally enable hotkeys:
@@ -31,34 +30,23 @@ const log = logger('system');
/**
* The names of the focus regions.
*/
const zFocusRegionName = z.enum([
'launchpad',
'viewer',
'gallery',
'boards',
'layers',
'canvas',
'workflows',
'progress',
'settings',
]);
export type FocusRegionName = z.infer<typeof zFocusRegionName>;
export type FocusRegionName = 'gallery' | 'layers' | 'canvas' | 'workflows' | 'viewer';
/**
* A map of focus regions to the elements that are part of that region.
*/
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = zFocusRegionName.options.values().reduce(
(acc, region) => {
acc[region] = new Set<HTMLElement>();
return acc;
},
{} as Record<FocusRegionName, Set<HTMLElement>>
);
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = {
gallery: new Set<HTMLElement>(),
layers: new Set<HTMLElement>(),
canvas: new Set<HTMLElement>(),
workflows: new Set<HTMLElement>(),
viewer: new Set<HTMLElement>(),
} as const;
/**
* The currently-focused region or `null` if no region is focused.
*/
const $focusedRegion = atom<FocusRegionName | null>(null);
export const $focusedRegion = atom<FocusRegionName | null>(null);
/**
* A map of focus regions to atoms that indicate if that region is focused.
@@ -74,13 +62,11 @@ const FOCUS_REGIONS = objectKeys(REGION_TARGETS).reduce(
/**
* Sets the focused region, logging a trace level message.
*/
export const setFocusedRegion = (region: FocusRegionName | null) => {
const setFocus = (region: FocusRegionName | null) => {
$focusedRegion.set(region);
log.trace(`Focus changed: ${region}`);
};
export const getFocusedRegion = () => $focusedRegion.get();
type UseFocusRegionOptions = {
focusOnMount?: boolean;
};
@@ -113,14 +99,14 @@ export const useFocusRegion = (
REGION_TARGETS[region].add(element);
if (focusOnMount) {
setFocusedRegion(region);
setFocus(region);
}
return () => {
REGION_TARGETS[region].delete(element);
if (REGION_TARGETS[region].size === 0 && $focusedRegion.get() === region) {
setFocusedRegion(null);
setFocus(null);
}
};
}, [options, ref, region]);
@@ -177,7 +163,7 @@ const onFocus = (_: FocusEvent) => {
return;
}
setFocusedRegion(focusedRegion);
setFocus(focusedRegion);
};
/**

View File

@@ -1,115 +0,0 @@
import { useStore } from '@nanostores/react';
import { WrappedError } from 'common/util/result';
import type { Atom } from 'nanostores';
import { atom } from 'nanostores';
import { useCallback, useEffect, useMemo, useState } from 'react';
type SuccessState<T> = {
status: 'success';
value: T;
error: null;
};
type ErrorState = {
status: 'error';
value: null;
error: Error;
};
type PendingState = {
status: 'pending';
value: null;
error: null;
};
type IdleState = {
status: 'idle';
value: null;
error: null;
};
export type State<T> = IdleState | PendingState | SuccessState<T> | ErrorState;
type UseAsyncStateOptions = {
immediate?: boolean;
};
type UseAsyncReturn<T> = {
$state: Atom<State<T>>;
trigger: () => Promise<void>;
reset: () => void;
};
export const useAsyncState = <T>(execute: () => Promise<T>, options?: UseAsyncStateOptions): UseAsyncReturn<T> => {
const $state = useState(() =>
atom<State<T>>({
status: 'idle',
value: null,
error: null,
})
)[0];
const trigger = useCallback(async () => {
$state.set({
status: 'pending',
value: null,
error: null,
});
try {
const value = await execute();
$state.set({
status: 'success',
value,
error: null,
});
} catch (error) {
$state.set({
status: 'error',
value: null,
error: WrappedError.wrap(error),
});
}
}, [$state, execute]);
const reset = useCallback(() => {
$state.set({
status: 'idle',
value: null,
error: null,
});
}, [$state]);
useEffect(() => {
if (options?.immediate) {
trigger();
}
}, [options?.immediate, trigger]);
const api = useMemo(
() =>
({
$state,
trigger,
reset,
}) satisfies UseAsyncReturn<T>,
[$state, trigger, reset]
);
return api;
};
type UseAsyncReturnReactive<T> = {
state: State<T>;
trigger: () => Promise<void>;
reset: () => void;
};
export const useAsyncStateReactive = <T>(
execute: () => Promise<T>,
options?: UseAsyncStateOptions
): UseAsyncReturnReactive<T> => {
const { $state, trigger, reset } = useAsyncState(execute, options);
const state = useStore($state);
return { state, trigger, reset };
};

View File

@@ -73,7 +73,7 @@ export const useBoolean = (initialValue: boolean): UseBoolean => {
};
};
type UseDisclosure = {
export type UseDisclosure = {
isOpen: boolean;
open: () => void;
close: () => void;

View File

@@ -0,0 +1,165 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/**
* Adapted from https://github.com/chakra-ui/chakra-ui/blob/v2/packages/hooks/src/use-outside-click.ts
*
* The main change here is to support filtering of outside clicks via a `filter` function.
*
* This lets us work around issues with portals and components like popovers, which typically close on an outside click.
*
* For example, consider a popover that has a custom drop-down component inside it, which uses a portal to render
* the drop-down options. The original outside click handler would close the popover when clicking on the drop-down options,
* because the click is outside the popover - but we expect the popover to stay open in this case.
*
* A filter function like this can fix that:
*
* ```ts
* const filter = (el: HTMLElement) => el.className.includes('chakra-portal') || el.id.includes('react-select')
* ```
*
* This ignores clicks on react-select-based drop-downs and Chakra UI portals and is used as the default filter.
*/
import { useCallback, useEffect, useRef } from 'react';
type FilterFunction = (el: HTMLElement | SVGElement) => boolean;
export function useCallbackRef<T extends (...args: any[]) => any>(
callback: T | undefined,
deps: React.DependencyList = []
) {
const callbackRef = useRef(callback);
useEffect(() => {
callbackRef.current = callback;
});
// eslint-disable-next-line react-hooks/exhaustive-deps
return useCallback(((...args) => callbackRef.current?.(...args)) as T, deps);
}
export interface UseOutsideClickProps {
/**
* Whether the hook is enabled
*/
enabled?: boolean;
/**
* The reference to a DOM element.
*/
ref: React.RefObject<HTMLElement | null>;
/**
* Function invoked when a click is triggered outside the referenced element.
*/
handler?: (e: Event) => void;
/**
* A function that filters the elements that should be considered as outside clicks.
*
* If omitted, a default filter function that ignores clicks in Chakra UI portals and react-select components is used.
*/
filter?: FilterFunction;
}
export const DEFAULT_FILTER: FilterFunction = (el) => {
if (el instanceof SVGElement) {
// SVGElement's type appears to be incorrect. Its className is not a string, which causes `includes` to fail.
// Let's assume that SVG elements with a class name are not part of the portal and should not be filtered.
return false;
}
return el.className.includes('chakra-portal') || el.id.includes('react-select');
};
/**
* Example, used in components like Dialogs and Popovers, so they can close
* when a user clicks outside them.
*/
export function useFilterableOutsideClick(props: UseOutsideClickProps) {
const { ref, handler, enabled = true, filter = DEFAULT_FILTER } = props;
const savedHandler = useCallbackRef(handler);
const stateRef = useRef({
isPointerDown: false,
ignoreEmulatedMouseEvents: false,
});
const state = stateRef.current;
useEffect(() => {
if (!enabled) {
return;
}
const onPointerDown: any = (e: PointerEvent) => {
if (isValidEvent(e, ref, filter)) {
state.isPointerDown = true;
}
};
const onMouseUp: any = (event: MouseEvent) => {
if (state.ignoreEmulatedMouseEvents) {
state.ignoreEmulatedMouseEvents = false;
return;
}
if (state.isPointerDown && handler && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const onTouchEnd = (event: TouchEvent) => {
state.ignoreEmulatedMouseEvents = true;
if (handler && state.isPointerDown && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const doc = getOwnerDocument(ref.current);
doc.addEventListener('mousedown', onPointerDown, true);
doc.addEventListener('mouseup', onMouseUp, true);
doc.addEventListener('touchstart', onPointerDown, true);
doc.addEventListener('touchend', onTouchEnd, true);
return () => {
doc.removeEventListener('mousedown', onPointerDown, true);
doc.removeEventListener('mouseup', onMouseUp, true);
doc.removeEventListener('touchstart', onPointerDown, true);
doc.removeEventListener('touchend', onTouchEnd, true);
};
}, [handler, ref, savedHandler, state, enabled, filter]);
}
function isValidEvent(event: Event, ref: React.RefObject<HTMLElement | null>, filter?: FilterFunction): boolean {
const target = (event.composedPath?.()[0] ?? event.target) as HTMLElement;
if (target) {
const doc = getOwnerDocument(target);
if (!doc.contains(target)) {
return false;
}
}
if (ref.current?.contains(target)) {
return false;
}
// This is the main logic change from the original hook.
if (filter) {
// Check if the click is inside an element matching the filter.
// This is used for portal-awareness or other general exclusion cases.
let currentElement: HTMLElement | null = target;
// Traverse up the DOM tree from the target element.
while (currentElement && currentElement !== document.body) {
if (filter(currentElement)) {
return false;
}
currentElement = currentElement.parentElement;
}
}
// If the click is not inside the ref and not inside a portal, it's a valid outside click.
return true;
}
function getOwnerDocument(node?: Element | null): Document {
return node?.ownerDocument ?? document;
}

View File

@@ -1,17 +1,13 @@
import { useAppStore } from 'app/store/storeHooks';
import { useDeleteImageModalApi } from 'features/deleteImageModal/store/state';
import { selectSelection } from 'features/gallery/store/gallerySelectors';
import { useAppDispatch } from 'app/store/storeHooks';
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
import { useDeleteCurrentQueueItem } from 'features/queue/hooks/useDeleteCurrentQueueItem';
import { useInvoke } from 'features/queue/hooks/useInvoke';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { getFocusedRegion } from './focus';
import { setActiveTab } from 'features/ui/store/uiSlice';
export const useGlobalHotkeys = () => {
const { dispatch, getState } = useAppStore();
const dispatch = useAppDispatch();
const isModelManagerEnabled = useFeatureStatus('modelManager');
const queue = useInvoke();
@@ -69,7 +65,7 @@ export const useGlobalHotkeys = () => {
id: 'selectGenerateTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('generate');
dispatch(setActiveTab('generate'));
},
dependencies: [dispatch],
});
@@ -78,7 +74,7 @@ export const useGlobalHotkeys = () => {
id: 'selectCanvasTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('canvas');
dispatch(setActiveTab('canvas'));
},
dependencies: [dispatch],
});
@@ -87,7 +83,7 @@ export const useGlobalHotkeys = () => {
id: 'selectUpscalingTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('upscaling');
dispatch(setActiveTab('upscaling'));
},
dependencies: [dispatch],
});
@@ -96,7 +92,7 @@ export const useGlobalHotkeys = () => {
id: 'selectWorkflowsTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('workflows');
dispatch(setActiveTab('workflows'));
},
dependencies: [dispatch],
});
@@ -105,7 +101,7 @@ export const useGlobalHotkeys = () => {
id: 'selectModelsTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('models');
dispatch(setActiveTab('models'));
},
options: {
enabled: isModelManagerEnabled,
@@ -117,26 +113,24 @@ export const useGlobalHotkeys = () => {
id: 'selectQueueTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('queue');
dispatch(setActiveTab('queue'));
},
dependencies: [dispatch, isModelManagerEnabled],
});
const deleteImageModalApi = useDeleteImageModalApi();
useRegisteredHotkeys({
id: 'deleteSelection',
category: 'gallery',
callback: () => {
const focusedRegion = getFocusedRegion();
if (focusedRegion !== 'gallery' && focusedRegion !== 'viewer') {
return;
}
const selection = selectSelection(getState());
if (!selection.length) {
return;
}
deleteImageModalApi.delete(selection);
},
dependencies: [getState, deleteImageModalApi],
});
// TODO: implement delete - needs to handle gallery focus, which has changed w/ dockview
// useRegisteredHotkeys({
// id: 'deleteSelection',
// category: 'gallery',
// callback: () => {
// if (!selection.length) {
// return;
// }
// deleteImageModal.delete(selection);
// },
// options: {
// enabled: (isGalleryFocused || isImageViewerFocused) && isDeleteEnabledByTab && !isWorkflowsFocused,
// },
// dependencies: [isWorkflowsFocused, isDeleteEnabledByTab, selection, isWorkflowsFocused],
// });
};

View File

@@ -2,10 +2,10 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { groupBy, reduce } from 'es-toolkit/compat';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { selectSystemShouldEnableModelDescriptions } from 'features/system/store/systemSlice';
import { groupBy, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';

View File

@@ -21,15 +21,11 @@ type UseImageUploadButtonArgs =
isDisabled?: boolean;
allowMultiple: false;
onUpload?: (imageDTO: ImageDTO) => void;
onUploadStarted?: (files: File) => void;
onError?: (error: unknown) => void;
}
| {
isDisabled?: boolean;
allowMultiple: true;
onUpload?: (imageDTOs: ImageDTO[]) => void;
onUploadStarted?: (files: File[]) => void;
onError?: (error: unknown) => void;
};
const log = logger('gallery');
@@ -53,13 +49,7 @@ const log = logger('gallery');
* <Button {...getUploadButtonProps()} /> // will open the file dialog on click
* <input {...getUploadInputProps()} /> // hidden, handles native upload functionality
*/
export const useImageUploadButton = ({
onUpload,
isDisabled,
allowMultiple,
onUploadStarted,
onError,
}: UseImageUploadButtonArgs) => {
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const [uploadImage, request] = useUploadImageMutation();
@@ -81,7 +71,6 @@ export const useImageUploadButton = ({
}
const file = files[0];
assert(file !== undefined); // should never happen
onUploadStarted?.(file);
const imageDTO = await uploadImage({
file,
image_category: 'user',
@@ -93,8 +82,6 @@ export const useImageUploadButton = ({
onUpload(imageDTO);
}
} else {
onUploadStarted?.(files);
let imageDTOs: ImageDTO[] = [];
if (isClientSideUploadEnabled && files.length > 1) {
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
@@ -115,7 +102,6 @@ export const useImageUploadButton = ({
}
}
} catch (error) {
onError?.(error);
toast({
id: 'UPLOAD_FAILED',
title: t('toast.imageUploadFailed'),
@@ -123,17 +109,7 @@ export const useImageUploadButton = ({
});
}
},
[
allowMultiple,
onUploadStarted,
uploadImage,
autoAddBoardId,
onUpload,
isClientSideUploadEnabled,
clientSideUpload,
onError,
t,
]
[allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t]
);
const onDropRejected = useCallback(

View File

@@ -1,7 +1,7 @@
import { useAppStore } from 'app/store/storeHooks';
import { debounce } from 'es-toolkit/compat';
import { useAppStore } from 'app/store/nanostores/store';
import type { Dimensions } from 'features/controlLayers/store/types';
import { selectUiSlice, textAreaSizesStateChanged } from 'features/ui/store/uiSlice';
import { debounce } from 'lodash-es';
import { type RefObject, useCallback, useEffect, useMemo } from 'react';
type Options = {

View File

@@ -0,0 +1,132 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { uniq } from 'lodash-es';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
import type { AnyModelConfig } from 'services/api/types';
import { useGroupedModelCombobox } from './useGroupedModelCombobox';
type UseRelatedGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
isLoading?: boolean;
groupByType?: boolean;
};
// Custom hook to overlay the grouped model combobox with related models on top!
// Cleaner than hooking into useGroupedModelCombobox with a flag to enable/disable the related models
// Also allows for related models to be shown conditionally with some pretty simple logic if it ends up as a config flag.
type UseRelatedGroupedModelComboboxReturn = {
value: ComboboxOption | undefined | null;
options: GroupBase<ComboboxOption>[];
onChange: ComboboxOnChange;
placeholder: string;
noOptionsMessage: () => string;
};
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
const keys: string[] = [];
const main = params.model;
const vae = params.vae;
const refiner = params.refinerModel;
const controlnet = params.controlLora;
if (main) {
keys.push(main.key);
}
if (vae) {
keys.push(vae.key);
}
if (refiner) {
keys.push(refiner.key);
}
if (controlnet) {
keys.push(controlnet.key);
}
for (const { model } of loras.loras) {
keys.push(model.key);
}
return uniq(keys);
});
export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
modelConfigs,
selectedModel,
onChange,
isLoading = false,
getIsDisabled,
groupByType,
}: UseRelatedGroupedModelComboboxArg<T>): UseRelatedGroupedModelComboboxReturn {
const { t } = useTranslation();
const selectedKeys = useAppSelector(selectSelectedModelKeys);
const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, {
selectFromResult: ({ data }) => {
if (!data) {
return { relatedKeys: EMPTY_ARRAY };
}
return { relatedKeys: data };
},
});
// Base grouped options
const base = useGroupedModelCombobox({
modelConfigs,
selectedModel,
onChange,
getIsDisabled,
isLoading,
groupByType,
});
const options = useMemo(() => {
if (relatedKeys.length === 0) {
return base.options;
}
const relatedOptions: ComboboxOption[] = [];
const updatedGroups: GroupBase<ComboboxOption>[] = [];
for (const group of base.options) {
const remainingOptions: ComboboxOption[] = [];
for (const option of group.options) {
if (relatedKeys.includes(option.value)) {
relatedOptions.push({ ...option, label: `* ${option.label}` });
} else {
remainingOptions.push(option);
}
}
if (remainingOptions.length > 0) {
updatedGroups.push({
label: group.label,
options: remainingOptions,
});
}
}
if (relatedOptions.length > 0) {
return [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups];
} else {
return updatedGroups;
}
}, [base.options, relatedKeys, t]);
return {
...base,
options,
};
}

View File

@@ -0,0 +1,28 @@
import type { Selector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import type { Atom, WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import { useEffect, useState } from 'react';
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const useSelectorAsAtom = <T extends Selector<RootState, any>>(selector: T): Atom<ReturnType<T>> => {
const store = useAppStore();
const $atom = useState<WritableAtom<ReturnType<T>>>(() => atom<ReturnType<T>>(selector(store.getState())))[0];
useEffect(() => {
const unsubscribe = store.subscribe(() => {
const prev = $atom.get();
const next = selector(store.getState());
if (prev !== next) {
$atom.set(next);
}
});
return () => {
unsubscribe();
};
}, [$atom, selector, store]);
return $atom;
};

View File

@@ -1,20 +0,0 @@
export type Deferred<T> = {
promise: Promise<T>;
resolve: (value: T) => void;
reject: (error: Error) => void;
};
/**
* Create a promise and expose its resolve and reject callbacks.
*/
export const createDeferredPromise = <T>(): Deferred<T> => {
let resolve!: (value: T) => void;
let reject!: (error: Error) => void;
const promise = new Promise<T>((res, rej) => {
resolve = res;
reject = rej;
});
return { promise, resolve, reject };
};

View File

@@ -1,5 +1,5 @@
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
import { random } from 'es-toolkit/compat';
import { random } from 'lodash-es';
type GenerateSeedsArg = {
count: number;

View File

@@ -0,0 +1,6 @@
/**
* Get the keys of an object. This is a wrapper around `Object.keys` that types the result as an array of the keys of the object.
* @param obj The object to get the keys of.
* @returns The keys of the object.
*/
export const objectKeys = <T extends Record<string, unknown>>(obj: T) => Object.keys(obj) as Array<keyof T>;

View File

@@ -89,7 +89,7 @@ export function withResult<T>(fn: () => T): Result<T> {
try {
return new Ok(fn());
} catch (error) {
return new Err(error instanceof Error ? error : new WrappedError(error));
return new Err(error instanceof Error ? error : new Error(String(error)));
}
}
@@ -104,23 +104,6 @@ export async function withResultAsync<T>(fn: () => Promise<T>): Promise<Result<T
const result = await fn();
return new Ok(result);
} catch (error) {
return new Err(error instanceof Error ? error : new WrappedError(error));
}
}
export class WrappedError extends Error {
error: unknown;
constructor(error: unknown) {
super('Wrapped Error');
this.name = this.constructor.name;
this.error = error;
}
static wrap(error: unknown): Error | WrappedError {
if (error instanceof Error) {
return error;
}
return new WrappedError(error);
return new Err(error instanceof Error ? error : new Error(String(error)));
}
}

View File

@@ -1,4 +1,4 @@
import type { z } from 'zod/v4';
import type { z } from 'zod';
/**
* Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.

View File

@@ -0,0 +1,182 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
ContextMenu,
Divider,
Flex,
IconButton,
Menu,
MenuButton,
MenuList,
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
} from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
import { CanvasAlertsInvocationProgress } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsInvocationProgress';
import { CanvasAlertsPreserveMask } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsPreserveMask';
import { CanvasAlertsSelectedEntityStatus } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSelectedEntityStatus';
import { CanvasContextMenuGlobalMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems';
import { CanvasContextMenuSelectedEntityMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuSelectedEntityMenuItems';
import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea';
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
import { InvokeCanvasComponent } from 'features/controlLayers/components/InvokeCanvasComponent';
import { SelectObject } from 'features/controlLayers/components/SelectObject/SelectObject';
import { CanvasSessionContextProvider } from 'features/controlLayers/components/SimpleSession/context';
import { GenerateLaunchpadPanel } from 'features/controlLayers/components/SimpleSession/GenerateLaunchpadPanel';
import { StagingAreaItemsList } from 'features/controlLayers/components/SimpleSession/StagingAreaItemsList';
import { StagingAreaToolbar } from 'features/controlLayers/components/StagingArea/StagingAreaToolbar';
import { CanvasToolbar } from 'features/controlLayers/components/Toolbar/CanvasToolbar';
import { Transform } from 'features/controlLayers/components/Transform/Transform';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDynamicGrid, selectShowHUD } from 'features/controlLayers/store/canvasSettingsSlice';
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar';
import { memo, useCallback } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
const FOCUS_REGION_STYLES: SystemStyleObject = {
width: 'full',
height: 'full',
};
const MenuContent = memo(() => {
return (
<CanvasManagerProviderGate>
<MenuList>
<CanvasContextMenuSelectedEntityMenuItems />
<CanvasContextMenuGlobalMenuItems />
</MenuList>
</CanvasManagerProviderGate>
);
});
MenuContent.displayName = 'MenuContent';
const canvasBgSx = {
position: 'relative',
w: 'full',
h: 'full',
borderRadius: 'base',
overflow: 'hidden',
bg: 'base.900',
'&[data-dynamic-grid="true"]': {
bg: 'base.850',
},
};
export const AdvancedSession = memo(({ id }: { id: string | null }) => {
const dynamicGrid = useAppSelector(selectDynamicGrid);
const showHUD = useAppSelector(selectShowHUD);
const renderMenu = useCallback(() => {
return <MenuContent />;
}, []);
return (
<Tabs w="full" h="full">
<TabList>
<Tab>Welcome</Tab>
<Tab>Workspace</Tab>
<Tab>Viewer</Tab>
</TabList>
<TabPanels w="full" h="full">
<TabPanel w="full" h="full" justifyContent="center">
<GenerateLaunchpadPanel />
</TabPanel>
<TabPanel w="full" h="full">
<FocusRegionWrapper region="canvas" sx={FOCUS_REGION_STYLES}>
<Flex
tabIndex={-1}
borderRadius="base"
position="relative"
flexDirection="column"
height="full"
width="full"
gap={2}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
<CanvasManagerProviderGate>
<CanvasToolbar />
</CanvasManagerProviderGate>
<Divider />
<ContextMenu<HTMLDivElement> renderMenu={renderMenu} withLongPress={false}>
{(ref) => (
<Flex ref={ref} sx={canvasBgSx} data-dynamic-grid={dynamicGrid}>
<InvokeCanvasComponent />
<CanvasManagerProviderGate>
<Flex
position="absolute"
flexDir="column"
top={1}
insetInlineStart={1}
pointerEvents="none"
gap={2}
alignItems="flex-start"
>
{showHUD && <CanvasHUD />}
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsInvocationProgress />
</Flex>
<Flex position="absolute" top={1} insetInlineEnd={1}>
<Menu>
<MenuButton as={IconButton} icon={<PiDotsThreeOutlineVerticalFill />} colorScheme="base" />
<MenuContent />
</Menu>
</Flex>
</CanvasManagerProviderGate>
</Flex>
)}
</ContextMenu>
{id !== null && (
<CanvasManagerProviderGate>
<CanvasSessionContextProvider type="advanced" id={id}>
<Flex
position="absolute"
flexDir="column"
bottom={4}
gap={2}
align="center"
justify="center"
left={4}
right={4}
>
<Flex position="relative" maxW="full" w="full" h={108}>
<StagingAreaItemsList />
</Flex>
<Flex gap={2}>
<StagingAreaToolbar />
</Flex>
</Flex>
</CanvasSessionContextProvider>
</CanvasManagerProviderGate>
)}
<Flex position="absolute" bottom={4}>
<CanvasManagerProviderGate>
<Filter />
<Transform />
<SelectObject />
</CanvasManagerProviderGate>
</Flex>
<CanvasManagerProviderGate>
<CanvasDropArea />
</CanvasManagerProviderGate>
</Flex>
</FocusRegionWrapper>
</TabPanel>
<TabPanel w="full" h="full">
<Flex flexDir="column" w="full" h="full">
<ViewerToolbar />
<ImageViewer />
</Flex>
</TabPanel>
</TabPanels>
</Tabs>
);
});
AdvancedSession.displayName = 'AdvancedSession';

View File

@@ -1,23 +0,0 @@
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasAlertsSaveAllImagesToGallery = memo(() => {
const { t } = useTranslation();
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
if (!saveAllImagesToGallery) {
return null;
}
return (
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('controlLayers.settings.saveAllImagesToGallery.alert')}</AlertTitle>
</Alert>
);
});
CanvasAlertsSaveAllImagesToGallery.displayName = 'CanvasAlertsSaveAllImagesToGallery';

View File

@@ -1,4 +1,3 @@
import type { SpinnerProps } from '@invoke-ai/ui-library';
import { Spinner } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
@@ -6,7 +5,7 @@ import { useAllEntityAdapters } from 'features/controlLayers/contexts/EntityAdap
import { computed } from 'nanostores';
import { memo, useMemo } from 'react';
export const CanvasBusySpinner = memo((props: SpinnerProps) => {
export const CanvasBusySpinner = memo(() => {
const canvasManager = useCanvasManager();
const allEntityAdapters = useAllEntityAdapters();
const $isPendingRectCalculation = useMemo(
@@ -22,7 +21,7 @@ export const CanvasBusySpinner = memo((props: SpinnerProps) => {
const isCompositing = useStore(canvasManager.compositor.$isBusy);
if (isRasterizing || isCompositing || isPendingRectCalculation) {
return <Spinner opacity={0.3} {...props} />;
return <Spinner opacity={0.3} />;
}
return null;
});

View File

@@ -12,10 +12,6 @@ const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'regional_guidance_with_reference_image',
});
const addResizedControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'control_layer',
withResize: true,
});
export const CanvasDropArea = memo(() => {
const { t } = useTranslation();
@@ -49,6 +45,7 @@ export const CanvasDropArea = memo(() => {
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
@@ -57,14 +54,6 @@ export const CanvasDropArea = memo(() => {
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addResizedControlLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newResizedControlLayer')}
isDisabled={isBusy}
/>
</GridItem>
</Grid>
</>
);

Some files were not shown because too many files have changed in this diff Show More