Compare commits

..

577 Commits

Author SHA1 Message Date
psychedelicious
0ea88dc170 chore: release v4.2.9.dev9 2024-08-30 22:24:08 +10:00
psychedelicious
8369826d22 fix(ui): entity opacity number input focus prevents slider from opening 2024-08-30 22:20:49 +10:00
psychedelicious
0e354f5164 feat(ui): add merge visible for raster and inpaint mask layers
I don't think it makes sense to merge control layers or regional guidance layers because they have additional state.
2024-08-30 22:20:49 +10:00
psychedelicious
41f2ee2633 fix(ui): save to gallery rect too large
Was including all layer types in the rect - only want the raster layers.
2024-08-30 22:20:49 +10:00
psychedelicious
4e74006c5f fix(ui): canvasToBlob not raising error correctly 2024-08-30 22:20:49 +10:00
psychedelicious
48edb6e023 feat(ui): add save to gallery button 2024-08-30 22:20:49 +10:00
psychedelicious
aeae6af0a1 fix(ui): fix getRectUnion util, add some tests 2024-08-30 22:20:49 +10:00
psychedelicious
ab11d9af8e fix(ui): modals not staying open
TBH not sure exactly why this broke. Fixed by rollback back the use of a render prop in favor of global state. Also revised the API of `useBoolean` and `buildUseBoolean`.
2024-08-30 22:20:49 +10:00
psychedelicious
2e84327ca4 fix(ui): correct labels for generation tab origin 2024-08-30 22:20:49 +10:00
psychedelicious
fa6842121c fix(ui): context menu doesn't work for new entities
I do not understand why this fixes the issue, doesn't seem like it should. But it does.
2024-08-30 22:20:49 +10:00
psychedelicious
c402aa397d tidy(ui): organise tool module 2024-08-30 22:20:49 +10:00
psychedelicious
a58c8adc38 fix(ui): staging hotkeys enabled at wrong times 2024-08-30 22:20:49 +10:00
psychedelicious
d43e2d690e fix(ui): incorrect batch origin preventing progress/staging 2024-08-30 22:20:49 +10:00
psychedelicious
284f768810 feat(ui): restore minimal HUD 2024-08-30 22:20:49 +10:00
psychedelicious
e933d1ae2b feat(ui): remove unused asPreview for StageComponent 2024-08-30 22:20:49 +10:00
psychedelicious
1e134de771 chore(ui): lint 2024-08-30 22:20:49 +10:00
psychedelicious
29c47c8be5 chore: release v4.2.9.dev8 2024-08-30 22:20:49 +10:00
psychedelicious
e1122c541d feat(ui): revise generation mode logic
- Canvas generation mode is replace with a boolean `sendToCanvas` flag. When off, images generated on the canvas go to the gallery. When on, they get added to the staging area.
- When an image result is received, if its destination is the canvas, staging is automatically started.
- Updated queue list to show the destination column.
- Added `IconSwitch` component to represent binary choices, used for the new `sendToCanvas` flag and image viewer toggle.
- Remove the queue actions menu in `QueueControls`. Move the queue count badge to the cancel button.
- Redo layout of `QueueControls` to prevent duplicate queue count badges.
- Fix issue where gallery and options panels could show thru transparent regions of queue tab.
- Disable panel hotkeys when on mm/queue tabs.
2024-08-30 22:20:49 +10:00
psychedelicious
2f81d1ac83 chore(ui): typegen 2024-08-30 22:20:49 +10:00
psychedelicious
56fbe751db feat(app): add destination column to session_queue
The frontend needs to know where queue items came from (i.e. which tab), and where results are going to (i.e. send images to gallery or canvas). The `origin` column is not quite enough to represent this cleanly.

A `destination` column provides the frontend what it needs to handle incoming generations.
2024-08-30 22:20:49 +10:00
psychedelicious
93f1d67fbf tidy(ui): ViewerToggleMenu -> ViewerToggle 2024-08-30 22:20:49 +10:00
psychedelicious
9467b937ff feat(ui): alt quick switches to color picker 2024-08-30 22:20:49 +10:00
psychedelicious
4242e6e6c2 feat(ui): tweak add entity button layout 2024-08-30 22:20:49 +10:00
psychedelicious
9b39452b3e feat(ui): restore context menu for entity list 2024-08-30 22:20:49 +10:00
psychedelicious
85b23784cf feat(ui): add delete button to each layer 2024-08-30 22:20:49 +10:00
psychedelicious
085cc82926 feat(ui): add + buttons to entity categories 2024-08-30 22:20:49 +10:00
psychedelicious
0098c33f81 feat(ui): tweak brush fill UI 2024-08-30 22:20:49 +10:00
psychedelicious
292e00ab68 feat(ui): do not select layer on staging accept 2024-08-30 22:20:49 +10:00
psychedelicious
6c1fb2d06e fix(ui): more fiddly queue count layout stuff 2024-08-30 22:20:49 +10:00
psychedelicious
d60605fcd8 fix(ui): floating params panel invoke button loading state 2024-08-30 22:20:49 +10:00
psychedelicious
38ed720ff2 feat(ui): move canvas undo/redo to hook 2024-08-30 22:20:49 +10:00
psychedelicious
22203b8eb0 fix(ui): queue count badge positioning 2024-08-30 22:20:49 +10:00
psychedelicious
cf5fa792a1 fix(ui): add node cmdk only enabled on workflows tab 2024-08-30 22:20:49 +10:00
psychedelicious
c636633a8e chore: release v4.2.9.dev7 2024-08-30 22:20:49 +10:00
psychedelicious
55fe1ebc53 fix(ui): pending node connection stuck 2024-08-30 22:20:49 +10:00
psychedelicious
3c2fa6b475 chore(ui): lint 2024-08-30 22:20:49 +10:00
psychedelicious
9b927de2e0 chore: release v4.2.9.dev6 2024-08-30 22:20:49 +10:00
psychedelicious
6a62854e7d feat(ui): migrate add node popover to cmdk
Put this together as a way to figure out the library before moving on to the full app cmdk. Works great.
2024-08-30 22:20:49 +10:00
psychedelicious
312093cbb0 fix(ui): schema parsing now that node_pack is guaranteed to be present 2024-08-30 22:20:49 +10:00
psychedelicious
06fe14e1fc chore(ui): typegen 2024-08-30 22:20:49 +10:00
psychedelicious
1b54e58726 fix(app): node_pack not added to openapi schema correctly 2024-08-30 22:20:49 +10:00
psychedelicious
219d7c9611 fix(ui): unnecessary z-index on invoke button 2024-08-30 22:20:49 +10:00
psychedelicious
9f742a669e feat(ui): split settings modal 2024-08-30 22:20:49 +10:00
psychedelicious
41e324fd51 perf(ui): disable useInert on modals
This hook forcibly updates _all_ portals with `data-hidden=true` when the modal opens - then reverts it when the modal closes. It's intended to help screen readers. Unfortunately, this absolutely tanks performance because we have many portals. React needs to do alot of layout calculations (not re-renders).

IMO this behaviour is a bug in chakra. The modals which generated the portals are hidden by default, so this data attr should really be set by default. Dunno why it isn't.
2024-08-30 22:20:36 +10:00
psychedelicious
ce55a96125 feat(ui): fix queue item count badge positioning
Previously this badge, floating over the queue menu button next to the invoke button, was rendered within the existing layout. When I initially positioned it, the app layout interfered - it would extend into an area reserved for a flex gap, which cut off the badge.

As a (bad) workaround, I had shifted the whole app down a few pixels to make room for it. What I should have done is what I've done in this commit - render the badge in a portal to take it out of the layout so we don't need that extra vertical padding.

Sleekified some styling a bit too.
2024-08-30 22:20:36 +10:00
psychedelicious
64e60a7fde fix(ui): transparency effect not updating 2024-08-30 22:20:36 +10:00
psychedelicious
972f03960a feat(ui): tidy canvas toolbar buttons 2024-08-30 22:20:36 +10:00
psychedelicious
5a403f087d feat(ui): revised viewer toggle @joshistoast 2024-08-30 22:20:36 +10:00
psychedelicious
fe59d7f3b0 fix(ui): opacity reset value incorrect 2024-08-30 22:20:36 +10:00
psychedelicious
b2b2b73aed revert(ui): roll back flip, doesn't work with rotate yet 2024-08-30 22:20:36 +10:00
psychedelicious
20b563c4cb fix(ui): disable opacity slider fully when no valid entity selected 2024-08-30 22:20:36 +10:00
psychedelicious
263a0ef5b4 fix(ui): layer preview image sometimes not rendering
The canvas size was dynamic based on the container div's size. When the div was hidden (e.g. when selecting another tab), the container's effective size is 0. This resulted in the preview image canvas being drawn at a scale of 0.

Fixed by using an absolute size for the canvas container.
2024-08-30 22:20:36 +10:00
psychedelicious
e8723b7cd3 feat(ui): tweak regional prompt box styles 2024-08-30 22:20:36 +10:00
psychedelicious
03e05b2068 feat(ui): tweak enabled/locked toggle styles 2024-08-30 22:20:36 +10:00
psychedelicious
6c0482a71d feat(ui): tweak filter styling 2024-08-30 22:20:36 +10:00
psychedelicious
e6153e6fa4 feat(ui): add flip & reset to transform 2024-08-30 22:20:36 +10:00
psychedelicious
6d209c6cc3 tidy(ui): use helper to sync scaled bbox size on model change 2024-08-30 22:20:36 +10:00
psychedelicious
beb4e823dc fix(ui): randomize seed toggle linked to prompt concat 2024-08-30 22:20:36 +10:00
psychedelicious
61ba4c606b chore: release v4.2.9.dev5 2024-08-30 22:20:36 +10:00
psychedelicious
af840cedf3 chore(ui): lint 2024-08-30 22:20:36 +10:00
psychedelicious
0bf0bca03f feat(ui): generalize mask fill, add to action bar 2024-08-30 22:20:36 +10:00
psychedelicious
e470eaf8f3 feat(ui): implement interaction locking on layers 2024-08-30 22:20:36 +10:00
psychedelicious
377db3f726 feat(ui): iterate on layer actions
- Add lock toggle
- Tweak lock and enabled styles
- Update entity list action bar w/ delete & delete all
- Move add layer menu to action bar
- Adjust opacity slider style
2024-08-30 22:20:36 +10:00
psychedelicious
77f020a997 feat(ui): collapsible entity groups 2024-08-30 22:20:36 +10:00
psychedelicious
34e2eda625 tidy(ui): rename some classes to be consistent 2024-08-30 22:20:36 +10:00
psychedelicious
e1d559db69 feat(ui): tuned canvas undo/redo
- Throttle pushing to history for actions of the same type, starting with 1000ms throttle.
- History has a limit of 64 items, same as workflow editor
- Add clear history button
- Fix an issue where entity transformers would reset the entity state when the entity is fully transparent, resetting the redo stack. This could happen when you undo to the starting state of a layer
2024-08-30 22:20:36 +10:00
psychedelicious
23a98e2ed6 tidy(ui): move all undoable reducers back to canvas slice 2024-08-30 22:20:36 +10:00
psychedelicious
fe3b2ed357 fix(ui): dnd image count 2024-08-30 22:20:36 +10:00
psychedelicious
eedf81dcc5 fix(ui): canvas entity opacity scale 2024-08-30 22:20:36 +10:00
psychedelicious
dbef1a9e06 perf(ui): optimize all selectors 2
Mostly selector optimization. Still a few places to tidy up but I'll get to that later.
2024-08-30 22:20:36 +10:00
psychedelicious
a41406ca9a perf(ui): optimize all selectors 1
I learned that the inline selector syntax recreates the selector function on every render:

```ts
const val = useAppSelector((s) => s.slice.val)
```

Not good! Better is to create a selector outside the function and use it. Doing that for all selectors now, most of the way through now. Feels snappier.
2024-08-30 22:20:12 +10:00
psychedelicious
f126a61f66 feat(ui): rough out undo/redo on canvas 2024-08-30 22:20:12 +10:00
psychedelicious
89c79276f3 chore: release v4.2.9.dev4
Canvas dev build.
2024-08-30 22:20:12 +10:00
psychedelicious
423e463b95 fix(ui): handle error from internal konva method
We are dipping into konva's private API for preview images and it appears to be unsafe (got an error once). Wrapped in a try/catch.
2024-08-30 22:20:12 +10:00
psychedelicious
52202e45de feat(ui): split out loras state from canvas rendering state 2024-08-30 22:20:12 +10:00
psychedelicious
100832c66d feat(ui): split out session state from canvas rendering state 2024-08-30 22:20:12 +10:00
psychedelicious
a58b91b221 feat(ui): split out settings state from canvas rendering state 2024-08-30 22:20:12 +10:00
psychedelicious
3af6d79852 feat(ui): split out tool state from canvas rendering state 2024-08-30 22:20:12 +10:00
psychedelicious
1303e18e93 feat(ui): split out params/compositing state from canvas rendering state
First step to restoring undo/redo - the undoable state must be in its own slice. So params and settings must be isolated.
2024-08-30 22:20:12 +10:00
psychedelicious
301da97670 feat(ui): add CanvasModuleBase class to standardize canvas APIs
I did this ages ago but undid it for some reason, not sure why. Caught a few issues related to subscriptions.
2024-08-30 22:20:12 +10:00
psychedelicious
17e76981bb feat(ui): move selected tool and tool buffer out of redux
This ephemeral state can live in the canvas classes.
2024-08-30 22:20:12 +10:00
psychedelicious
9c1732e2bb feat(ui): move ephemeral state into canvas classes
Things like `$lastCursorPos` are now created within the canvas drawing classes. Consumers in react access them via `useCanvasManager`.

For example:
```tsx
const canvasManager = useCanvasManager();
const lastCursorPos = useStore(canvasManager.stateApi.$lastCursorPos);
```
2024-08-30 22:20:12 +10:00
psychedelicious
a3179e7a3f feat(ui): normalize all actions to accept an entityIdentifier
Previously, canvas actions specific to an entity type only needed the id of that entity type. This allowed you to pass in the id of an entity of the wrong type.

All actions for a specific entity now take a full entity identifier, and the entity identifier type can be narrowed.

`selectEntity` and `selectEntityOrThrow` now need a full entity identifier, and narrow their return values to a specific entity type _if_ the entity identifier is narrowed.

The types for canvas entities are updated with optional type parameters for this purpose.

All reducers, actions and components have been updated.
2024-08-30 22:20:12 +10:00
psychedelicious
f86b50d18a feat(ui): move events into modules who care about them 2024-08-30 22:20:12 +10:00
psychedelicious
307885f505 fix(ui): color picker resets brush opacity 2024-08-30 22:20:12 +10:00
psychedelicious
4b49c1dd6b fix(ui): scaled bbox loses sync 2024-08-30 22:20:12 +10:00
psychedelicious
f917cefa84 feat(ui): add context menu to entity list 2024-08-30 22:20:12 +10:00
psychedelicious
bea98438fc chore(ui): bump @invoke-ai/ui-library 2024-08-30 22:20:12 +10:00
psychedelicious
17d3275086 fix(ui): missing vae precision in graph builders 2024-08-30 22:20:12 +10:00
psychedelicious
059b7a0fcf chore: release v4.2.9.dev3
Instead of using dates, just going to increment.
2024-08-30 22:20:12 +10:00
psychedelicious
05d3a989f6 feat(ui): use new Result utils for enqueueing 2024-08-30 22:20:12 +10:00
psychedelicious
590ae70c12 fix(ui): graph building issue w/ controlnet 2024-08-30 22:20:12 +10:00
psychedelicious
5240ec6e6f feat(ui): add Result type & helpers
Wrappers to capture errors and turn into results:
- `withResult` wraps a sync function
- `withResultAsync` wraps an async function

Comments, tests.
2024-08-30 22:20:12 +10:00
psychedelicious
04772b642c chore: release v4.2.9.dev20240824 2024-08-30 22:20:12 +10:00
psychedelicious
65f6cb416f fix(ui): lint & fix issues with adding regional ip adapters 2024-08-30 22:20:12 +10:00
psychedelicious
24c2028739 feat(ui): add knipignore tag
I'm not ready to delete some things but still want to build the app.
2024-08-30 22:20:12 +10:00
psychedelicious
b0db9a3f56 feat(ui): duplicate entity 2024-08-30 22:20:12 +10:00
psychedelicious
3ea83574c0 feat(ui): autocomplete on getPrefixeId 2024-08-30 22:20:12 +10:00
psychedelicious
05252a9bfc feat(ui): paste canvas gens back on source in generate mode 2024-08-30 22:20:12 +10:00
psychedelicious
ce854f086e chore(ui): typegen 2024-08-30 22:20:12 +10:00
psychedelicious
ff0c16978c feat(nodes): CanvasV2MaskAndCropInvocation can paste generated image back on source
This is needed for `Generate` mode.
2024-08-30 22:20:12 +10:00
psychedelicious
41cc650031 fix(ui): extraneous entity preview updates 2024-08-30 22:20:12 +10:00
psychedelicious
c3f7554053 fix(ui): newly-added entities are selected 2024-08-30 22:20:12 +10:00
psychedelicious
3f597a1c60 feat(ui): add crosshair to color picker 2024-08-30 22:20:12 +10:00
psychedelicious
ccffdf1878 fix(ui): color picker ignores alpha 2024-08-30 22:20:12 +10:00
psychedelicious
474089e892 fix(ui): calculate renderable entities correctly in tool module 2024-08-30 22:20:12 +10:00
psychedelicious
778e8ad161 feat(ui): better color picker 2024-08-30 22:20:12 +10:00
psychedelicious
9f29892c24 feat(ui): colored mask preview image 2024-08-30 22:20:12 +10:00
psychedelicious
56fd46a069 fix(ui): new rectangles don't trigger rerender 2024-08-30 22:20:12 +10:00
psychedelicious
579e594861 chore: bump version v4.2.9.dev20240823 2024-08-30 22:20:12 +10:00
psychedelicious
af3440fbe3 feat(ui): disable most interaction while filtering 2024-08-30 22:19:54 +10:00
psychedelicious
cc101f55c4 fix(ui): filter preview offset 2024-08-30 22:19:54 +10:00
psychedelicious
ef1adf07f5 feat(ui): tweak layout of staging area toolbar 2024-08-30 22:19:54 +10:00
psychedelicious
625c05d9be chore(ui): typegen 2024-08-30 22:19:54 +10:00
psychedelicious
8ad3d8f738 tidy(app): clean up app changes for canvas v2 2024-08-30 22:19:54 +10:00
psychedelicious
4759875733 feat(ui): use singleton for clear q confirm dialog 2024-08-30 22:19:54 +10:00
psychedelicious
768e6a3c55 fix(ui): rip out broken recall logic, NO TS ERRORS 2024-08-30 22:19:54 +10:00
psychedelicious
45bd85c039 chore(ui): lint 2024-08-30 22:19:54 +10:00
psychedelicious
9f94c5a8bd fix(ui): staging area interaction scopes 2024-08-30 22:19:54 +10:00
psychedelicious
23fdd65961 fix(ui): staging area actions 2024-08-30 22:19:54 +10:00
psychedelicious
8034195c30 tidy(ui): more cleanup 2024-08-30 22:19:54 +10:00
psychedelicious
08761127c9 fix(ui): upscale tab graph 2024-08-30 22:19:54 +10:00
psychedelicious
4a10010b6c fix(ui): sdxl graph builder 2024-08-30 22:19:54 +10:00
psychedelicious
14cc5e2453 fix(ui): select next entity in the list when deleting 2024-08-30 22:19:54 +10:00
psychedelicious
3d87adea60 feat(ui): fix delete layer hotkey 2024-08-30 22:19:54 +10:00
psychedelicious
36e8232ab6 tidy(ui): "eye dropper" -> "color picker" 2024-08-30 22:19:54 +10:00
psychedelicious
72722a73be tidy(ui): regional guidance buttons 2024-08-30 22:19:54 +10:00
psychedelicious
a09aa232a9 feat(ui): update entity list menu 2024-08-30 22:19:54 +10:00
psychedelicious
7ae8b64699 feat(ui): add log debug button 2024-08-30 22:19:54 +10:00
psychedelicious
60e0d17f34 chore(ui): lint 2024-08-30 22:19:54 +10:00
psychedelicious
bf8bef2f00 chore(ui): prettier 2024-08-30 22:19:54 +10:00
psychedelicious
b586d67bac chore(ui): eslint 2024-08-30 22:19:54 +10:00
psychedelicious
31e5e5af13 tidy(ui): remove unused stuff 4 2024-08-30 22:19:35 +10:00
psychedelicious
94871e88cd tidy(ui): remove unused stuff 3 2024-08-30 22:18:50 +10:00
psychedelicious
00e56d1968 tidy(ui): remove unused pkg @chakra-ui/react-use-size 2024-08-30 22:18:50 +10:00
psychedelicious
43672a53ab feat(ui): revise graph building for control layers, fix issues w/ invocation complete events 2024-08-30 22:18:50 +10:00
psychedelicious
45097ed2a6 feat(ui): use unique id for metadata in Graph class 2024-08-30 22:18:50 +10:00
psychedelicious
871f6b9f95 tidy(ui): remove unused stuff 2 2024-08-30 22:18:50 +10:00
psychedelicious
e6476e3c75 tidy(ui): remove unused stuff 2024-08-30 22:18:50 +10:00
psychedelicious
ac9b5f246d tidy(ui): reduce use of parseify util 2024-08-30 22:18:50 +10:00
psychedelicious
8bc72a2744 feat(ui): refine canvas entity list items & menus 2024-08-30 22:18:50 +10:00
psychedelicious
f76f1d89d7 feat(ui): canvas layer preview, revised reactivity for adapters 2024-08-30 22:18:50 +10:00
psychedelicious
7b54762b5e feat(ui): add SyncableMap
Can be used with useSyncExternal store to make a `Map` reactive.
2024-08-30 22:18:50 +10:00
psychedelicious
bc6faf6a6d tidy(ui): removed unused transform methods from canvasmanager 2024-08-30 22:18:50 +10:00
psychedelicious
e7ae1ac9b2 feat(ui): transform tool ux 2024-08-30 22:18:50 +10:00
psychedelicious
dcb436adb1 feat(ui): rough out canvas mode 2024-08-30 22:18:50 +10:00
psychedelicious
80f0441905 feat(ui): add canvas autosave checkbox 2024-08-30 22:18:50 +10:00
psychedelicious
8cde803654 fix(ui): memory leak when getting image DTO
must unsubscribe!
2024-08-30 22:18:50 +10:00
psychedelicious
62445680ad feat(ui): rework settings menu 2024-08-30 22:18:50 +10:00
psychedelicious
7685e36886 feat(ui): no entities fallback buttons 2024-08-30 22:18:50 +10:00
psychedelicious
4c196844bd perf(ui): optimize gallery image delete button rendering 2024-08-30 22:18:50 +10:00
psychedelicious
b36159bda4 feat(ui): remove "solid" background option 2024-08-30 22:18:50 +10:00
psychedelicious
b02948d49a tidy(ui): organise files and classes 2024-08-30 22:18:50 +10:00
psychedelicious
f442d206be tidy(ui): abstract compositing logic to module 2024-08-30 22:18:50 +10:00
psychedelicious
21ed6bccd8 fix(ui): fix canvas cache property access 2024-08-30 22:18:50 +10:00
psychedelicious
143ce7f00b tidy(ui): clean up CanvasFilter class 2024-08-30 22:18:50 +10:00
psychedelicious
28e716139b tidy(ui): clean up a few bits and bobs 2024-08-30 22:18:50 +10:00
psychedelicious
80a7c0c521 tidy(ui): abstract canvas rendering logic to module 2024-08-30 22:18:50 +10:00
psychedelicious
255ad3d2ad tidy(ui): abstract caching logic to module 2024-08-30 22:18:50 +10:00
psychedelicious
089bc9c7d8 tidy(ui): abstract worker logic to module 2024-08-30 22:18:50 +10:00
psychedelicious
ee7dafaf57 tidy(ui): abstract stage logic into module 2024-08-30 22:18:50 +10:00
psychedelicious
516ecdb0ee feat(ui): add entity group hiding 2024-08-30 22:18:50 +10:00
psychedelicious
b77675f74d feat(ui): move all caching out of redux
While we lose the benefit of the caches persisting across reloads, this is a much simpler way to handle things. If we need a persistent cache, we can explore it in the future.
2024-08-30 22:18:50 +10:00
psychedelicious
eea5c8efad feat(ui): revised rasterization caching
- use `stable-hash` to generate stable, non-crypto hashes for cache entries, instead of using deep object comparisons
- use an object to store image name caches
2024-08-30 22:18:50 +10:00
psychedelicious
09f1aac3a3 feat(ui): revise filter implementation 2024-08-30 22:18:50 +10:00
psychedelicious
dd1dcb5eba fix(ui): add button to delete inpaint mask 2024-08-30 22:18:50 +10:00
psychedelicious
757bd62ebe feat(ui): add contexts/hooks to access entity adapters directly 2024-08-30 22:18:50 +10:00
psychedelicious
5a3127949b feat(ui): add CanvasManagerProviderGate
This context waits to render its children its until the canvas manager is available. Then its children have access to the manager directly via hook.
2024-08-30 22:18:50 +10:00
psychedelicious
ced934c0a3 feat(ui) do not set $canvasManager until ready 2024-08-30 22:18:50 +10:00
psychedelicious
c32445084f fix(ui): inpaint mask naming 2024-08-30 22:18:50 +10:00
psychedelicious
9f1af0cdaa feat(ui): efficient canvas compositing
Also solves issue of exporting layers at different opacities than what is visible
2024-08-30 22:18:50 +10:00
psychedelicious
0d26cab400 feat(ui): allow multiple inpaint masks
This is easier than making it a nullable singleton
2024-08-30 22:18:50 +10:00
psychedelicious
c8de2da3fc fix(ui): missing rasterization cache invalidations 2024-08-30 22:18:50 +10:00
psychedelicious
ca089a105e feat(ui): iterate on filter UI, flow 2024-08-30 22:18:50 +10:00
psychedelicious
22000918d6 fix(ui): rehydration data loss 2024-08-30 22:18:50 +10:00
psychedelicious
6affc28da4 feat(ui): sort log namespaces 2024-08-30 22:18:50 +10:00
psychedelicious
f659995e1c fix(ui): do not merge arrays by index during rehydration 2024-08-30 22:18:50 +10:00
psychedelicious
56fb3e738f fix(ui): clone parsed data during state rehydration
Without this, the objects and arrays in `parsed` could be mutated, and the log statment would show the mutated data.
2024-08-30 22:18:50 +10:00
psychedelicious
56d450a907 fix(ui): fix logger filter
was accidetnally replacing the filter instead of appending to it.
2024-08-30 22:18:50 +10:00
psychedelicious
d3cdcef36b fix(ui): race condition queue status
Sequence of events causing the race condition:
- Enqueue batch
- Invalidate `SessionQueueStatus` tag
- Request updated queue status via HTTP - batch still processing at this point
- Batch completes
- Event emitted saying so
- Optimistically update the queue status cache, it is correct
- HTTP request makes it back and overwrites the optimistic update, indicating the batch is still in progress

FIxed by not invalidating the cache.
2024-08-30 22:18:50 +10:00
psychedelicious
19434e73b4 fix(ui): handle opacity for masks 2024-08-30 22:18:50 +10:00
psychedelicious
f7b3df9583 feat(ui): default background to checkerboard 2024-08-30 22:18:50 +10:00
psychedelicious
4da4b3bd50 feat(ui): clean up logging namespaces, allow skipping namespaces 2024-08-30 22:18:50 +10:00
psychedelicious
e83513882a chore(ui): bump ui library 2024-08-30 22:18:50 +10:00
psychedelicious
5adc784b6b fix(ui): do not allow drawing if layer disabled 2024-08-30 22:18:50 +10:00
psychedelicious
f177513523 fix(ui): stale state causing race conditions & extraneous renders 2024-08-30 22:18:50 +10:00
psychedelicious
8ebcf79b1a fix(ui): do not clear buffer when rendering "real" objects 2024-08-30 22:18:50 +10:00
psychedelicious
c7e5f24704 tidy(ui): remove "filter" from CanvasImageState 2024-08-30 22:18:50 +10:00
psychedelicious
ab3eb32ec8 feat(ui): better editable title 2024-08-30 22:18:50 +10:00
psychedelicious
d76509e5cb fix(ui): stroke eraserline 2024-08-30 22:18:50 +10:00
psychedelicious
04f56aab82 feat(ui): restore transparency effect for control layers 2024-08-30 22:18:50 +10:00
psychedelicious
c7913cbbbb feat(ui): use text cursor for entity title 2024-08-30 22:18:50 +10:00
psychedelicious
0556468518 tidy(ui): remove extraneous logging in CanvasStateApi 2024-08-30 22:18:49 +10:00
psychedelicious
1c7ef827b6 feat(ui): better buffer commit logic 2024-08-30 22:18:49 +10:00
psychedelicious
5720ed4d64 feat(ui): render buffer separately from "real" objects 2024-08-30 22:18:49 +10:00
psychedelicious
7f05af4a68 fix(ui): pixelRect should always be integer 2024-08-30 22:18:49 +10:00
psychedelicious
6db615ed5a fix(ui): only update stage attrs when stage itself is dragged 2024-08-30 22:18:49 +10:00
psychedelicious
465f020c86 feat(ui): add line simplification
This fixes some awkward issues where line segments stack up.
2024-08-30 22:18:49 +10:00
psychedelicious
f05b77088f fix(ui): various things listening when they need not listen 2024-08-30 22:18:49 +10:00
psychedelicious
80a5abf1ad feat(ui): layer opacity via caching 2024-08-30 22:18:49 +10:00
psychedelicious
7a6e8de60f feat(ui): reset view fits all visible objects 2024-08-30 22:18:49 +10:00
psychedelicious
8364fa74cf fix(ui): rerenders when changing canvas scale 2024-08-30 22:18:49 +10:00
psychedelicious
14f4566dd0 fix(ui): do not render rasterized layer unless renderObjects=true 2024-08-30 22:18:49 +10:00
psychedelicious
6145378923 feat(ui): revise app layout strategy, add interaction scopes for hotkeys 2024-08-30 22:18:49 +10:00
psychedelicious
68e2606427 feat(ui): tweak mask patterns 2024-08-30 22:18:49 +10:00
psychedelicious
0f3eb04d1a fix(ui): dynamic prompts recalcs when presets are loaded 2024-08-30 22:18:49 +10:00
psychedelicious
4a355323b2 fix(ui): use style preset prompts correctly 2024-08-30 22:18:49 +10:00
psychedelicious
8601fbb4ea fix(ui): discard selected staging image not all other images 2024-08-30 22:18:49 +10:00
psychedelicious
db885aa180 fix(ui): respect image size in staging preview 2024-08-30 22:18:49 +10:00
psychedelicious
c18fb980a2 tidy(ui): cleanup after events change 2024-08-30 22:18:49 +10:00
psychedelicious
b630dbdf20 feat(ui): move socket event handling out of redux
Download events and invocation status events (including progress images) are very frequent. There's no real need for these to pass through redux. Handling them outside redux is a significant performance win - far fewer store subscription calls, far fewer trips through middleware.

All event handling is moved outside middleware. Cleanup of unused actions and listeners to follow.
2024-08-30 22:18:49 +10:00
psychedelicious
29ac1b5e01 fix(ui): rebase conflicts 2024-08-30 22:18:49 +10:00
psychedelicious
506d3b079e fix(ui): update compositing rect when fill changes 2024-08-30 22:18:49 +10:00
psychedelicious
0670e6b53a feat(ui): add canvas background style 2024-08-30 22:18:49 +10:00
psychedelicious
76124ea35b feat(ui): mask layers choose own opacity 2024-08-30 22:18:49 +10:00
psychedelicious
6eae3470cd feat(ui): mask fill patterns 2024-08-30 22:18:49 +10:00
psychedelicious
c7ba7ac876 build(ui): add vite types to tsconfig 2024-08-30 22:18:49 +10:00
psychedelicious
edc733abd9 fix(ui): do not smooth pixel data when using eyeDropper 2024-08-30 22:18:49 +10:00
psychedelicious
a56ded664e tidy(ui): tool components & translations 2024-08-30 22:18:49 +10:00
psychedelicious
31ace5fb0c feat(ui): rough out eyedropper tool
It's a bit slow bc we are converting the stage to canvas on every mouse move. Also need to improve the visual but it works.
2024-08-30 22:18:49 +10:00
psychedelicious
11010236b3 fix(ui): ip adapters work 2024-08-30 22:18:49 +10:00
psychedelicious
5f061ac1e2 feat(ui): rename layers 2024-08-30 22:18:49 +10:00
psychedelicious
72919fa34e feat(ui): revise entity menus 2024-08-30 22:18:49 +10:00
psychedelicious
d5ca99fc3c feat(ui): split control layers from raster layers for UI and internal state, same rendering as raster layers 2024-08-30 22:18:49 +10:00
psychedelicious
e49b72ee4e feat(ui): implement cache for image rasterization, rip out some old controladapters code 2024-08-30 22:18:49 +10:00
psychedelicious
abe8db8154 feat(ui, app): use layer as control (wip) 2024-08-30 22:18:49 +10:00
psychedelicious
e0e5941384 feat(ui): add contextmenu for canvas entities 2024-08-30 22:18:49 +10:00
psychedelicious
86e1f4e8b0 feat(ui): more better logging & naming 2024-08-30 22:18:49 +10:00
psychedelicious
447d873ef0 feat(ui): better logging w/ path 2024-08-30 22:18:49 +10:00
psychedelicious
b21d613ce4 feat(ui): always show marks on canvas scale slider 2024-08-30 22:18:49 +10:00
psychedelicious
fc91adb32f fix(ui): do not import button from chakra 2024-08-30 22:18:49 +10:00
psychedelicious
71885db5fd fix(ui): scaled bbox preview 2024-08-30 22:18:49 +10:00
psychedelicious
b88d14b3df feat(ui): tidy up atoms 2024-08-30 22:18:49 +10:00
psychedelicious
d98d35a8a8 feat(ui): convert all my pubsubs to atoms
its the same but better
2024-08-30 22:18:49 +10:00
psychedelicious
87bc0ebd73 feat(ui): add trnalsation 2024-08-30 22:18:49 +10:00
psychedelicious
7b6ba3f690 fix(ui): give up on thumbnail loading, causes flash during transformer 2024-08-30 22:18:49 +10:00
psychedelicious
b0d8948428 fix(ui): depth anything v2 2024-08-30 22:18:49 +10:00
psychedelicious
b32d681cee tidy(ui): remove unused code, comments 2024-08-30 22:18:49 +10:00
psychedelicious
11a66d1d09 fix(ui): staging area works 2024-08-30 22:18:49 +10:00
psychedelicious
e41987f08c feat(nodes): temp disable canvas output crop 2024-08-30 22:18:49 +10:00
psychedelicious
34b57ec188 fix(ui): max scale 1 when reset view 2024-08-30 22:18:49 +10:00
psychedelicious
d74843be31 feat(ui): better scale changer component, reset view functionality 2024-08-30 22:18:49 +10:00
psychedelicious
1216c6f9c9 fix(ui): img2img 2024-08-30 22:18:49 +10:00
psychedelicious
865b6017d3 feat(ui): add manual scale controls 2024-08-30 22:18:49 +10:00
psychedelicious
922a021821 fix(ui): do not await clearBuffer 2024-08-30 22:18:49 +10:00
psychedelicious
0b5f4cac57 feat(ui): dnd image into layer 2024-08-30 22:18:49 +10:00
psychedelicious
c988c58c63 fix(ui): do not await commitBuffer 2024-08-30 22:18:49 +10:00
psychedelicious
ceb8cbf59e fix(ui): properly destroy entities in manager cleanup 2024-08-30 22:18:49 +10:00
psychedelicious
52e9f43c46 tidy(ui): clearer component names for regional guidance 2024-08-30 22:18:49 +10:00
psychedelicious
4e5e7761fc tidy(ui): clearer component names for ip adapter 2024-08-30 22:18:49 +10:00
psychedelicious
9879999a65 tidy(ui): clearer component names for inpaint mask 2024-08-30 22:18:49 +10:00
psychedelicious
bedaca70a3 tidy(ui): clearer component names for control adapters 2024-08-30 22:18:49 +10:00
psychedelicious
2dd2225d2e feat(ui): simplify canvas list item headers 2024-08-30 22:18:49 +10:00
psychedelicious
d82031eec1 fix(ui): ip adapter list item 2024-08-30 22:18:49 +10:00
psychedelicious
e5f2860b74 tidy(ui): clean up unused logic 2024-08-30 22:18:49 +10:00
psychedelicious
fa3560bb61 feat(ui): clean up state, add mutex for image loading, add thumbnail loading 2024-08-30 22:18:49 +10:00
psychedelicious
9b23f6ce30 chore(ui): add async-mutex dep 2024-08-30 22:18:49 +10:00
psychedelicious
5d6aa6cfd5 feat(ui): txt2img, img2img, inpaint & outpaint working 2024-08-30 22:18:49 +10:00
psychedelicious
7d1819335f feat(ui): no padding on transformer outlines 2024-08-30 22:18:49 +10:00
psychedelicious
539e7a3f2d feat(ui): restore object count to layer titles 2024-08-30 22:18:49 +10:00
psychedelicious
1686924ac8 tidy(ui): "useIsEntitySelected" -> "useEntityIsSelected" 2024-08-30 22:18:49 +10:00
psychedelicious
556c1dc67b tidy(ui): move transformer statics into class 2024-08-30 22:18:49 +10:00
psychedelicious
00f7093e65 tidy(ui): massive cleanup
- create a context for entity identifiers, massively simplifying UI for each entity int he list
- consolidate common redux actions
- remove now-unused code
2024-08-30 22:18:49 +10:00
psychedelicious
79eb11dce9 perf(ui): do not add duplicate points to lines 2024-08-30 22:18:49 +10:00
psychedelicious
0bf48c0d41 feat(ui): up line tension to 0.3 2024-08-30 22:18:49 +10:00
psychedelicious
3f33e5f770 perf(ui): disable stroke, perfect draw on compositing rect 2024-08-30 22:18:49 +10:00
psychedelicious
da3888ba9e tidy(ui): remove unused code, initial image 2024-08-30 22:18:49 +10:00
psychedelicious
a2f91b1055 tidy(ui): remove unused state & actions 2024-08-30 22:18:49 +10:00
psychedelicious
d26095dfa1 feat(ui): region mask rendering 2024-08-30 22:18:49 +10:00
psychedelicious
83e786bd1e feat(ui): esc cancels drawing buffer
maybe this is not wanted? we'll see
2024-08-30 22:18:49 +10:00
psychedelicious
4cae12a507 fix(ui): render transformer over objects, fix issue w/ inpaint rect color 2024-08-30 22:18:49 +10:00
psychedelicious
d8e3708e0f fix(ui): brush preview fill for inpaint/region 2024-08-30 22:18:49 +10:00
psychedelicious
f4de2fd3b1 fix(ui): no objects rendered until vis toggled 2024-08-30 22:18:49 +10:00
psychedelicious
e1cb30bbb4 feat(ui): inpaint mask transform 2024-08-30 22:18:49 +10:00
psychedelicious
97e0edc549 fix(ui): layer accidental early set isFirstRender=false 2024-08-30 22:18:49 +10:00
psychedelicious
f4e66bf14f fix(ui): inpaint mask rendering 2024-08-30 22:18:49 +10:00
psychedelicious
a6a7fe8aba feat(ui): wip inpaint mask uses new API 2024-08-30 22:18:49 +10:00
psychedelicious
a273f72560 feat(ui): move updatePosition to transformer 2024-08-30 22:18:49 +10:00
psychedelicious
b5126f45d6 feat(ui): move resetScale to transformer 2024-08-30 22:18:49 +10:00
psychedelicious
ba3bb7cbf3 tidy(ui): more imperative naming 2024-08-30 22:18:49 +10:00
psychedelicious
608279487b tidy(ui): use imperative names for setters in stateapi 2024-08-30 22:18:49 +10:00
psychedelicious
72b5374916 fix(ui): commit drawing buffer on tool change, fixing bbox not calculating 2024-08-30 22:18:49 +10:00
psychedelicious
08b03212ca fix(ui): sync transformer when requesting bbox calc 2024-08-30 22:18:49 +10:00
psychedelicious
7e341a05a1 tidy(ui): rename union CanvasEntity -> CanvasEntityState 2024-08-30 22:18:49 +10:00
psychedelicious
e665d08ee1 fix(ui): request rect calc immediately on transform, hiding rect 2024-08-30 22:18:49 +10:00
psychedelicious
ba6362dc9d feat(ui): move bbox calculation to transformer 2024-08-30 22:18:49 +10:00
psychedelicious
48f0797c43 feat(ui): use set for transformer subscriptions 2024-08-30 22:18:49 +10:00
psychedelicious
640b0c4939 tidy(ui): clean up worker tasks when complete 2024-08-30 22:18:49 +10:00
psychedelicious
287c61e277 tidy(ui): remove unused code in CanvasTool 2024-08-30 22:18:49 +10:00
psychedelicious
f7b2516109 feat(ui): use pubsub for isTransforming on manager 2024-08-30 22:18:49 +10:00
psychedelicious
b530eb49d4 docs(ui): update transformer docstrings 2024-08-30 22:18:49 +10:00
psychedelicious
fa94979ab6 feat(ui): revised event pubsub, transformer logic split out 2024-08-30 22:18:49 +10:00
psychedelicious
54f2acf5b9 feat(ui): add simple pubsub 2024-08-30 22:18:49 +10:00
psychedelicious
b6d845a4d0 feat(ui): document & clean up object renderer 2024-08-30 22:18:49 +10:00
psychedelicious
1095b7c37f feat(ui): split out object renderer 2024-08-30 22:18:49 +10:00
psychedelicious
136ffd97ca fix(ui): unable to hold shit while transforming to retain ratio 2024-08-30 22:18:49 +10:00
psychedelicious
80163d0af2 tidy(ui): rename canvas stuff 2024-08-30 22:18:49 +10:00
psychedelicious
e1c6e926e7 tidy(ui): consolidate getLoggingContext builders 2024-08-30 22:18:49 +10:00
psychedelicious
2bb74abf31 fix(ui): align all tools to 1px grid
- Offset brush tool by 0.5px when width is odd, ensuring each stroke edge is exactly on a pixel boundary
- Round the rect tool also
2024-08-30 22:18:49 +10:00
psychedelicious
0d4b91afe0 feat(ui): disable image smoothing on layers 2024-08-30 22:18:49 +10:00
psychedelicious
6c688d6878 fix(ui): round position when rasterizing layer 2024-08-30 22:18:49 +10:00
psychedelicious
243feecef9 feat(ui): continue modularizing transform 2024-08-30 22:18:49 +10:00
psychedelicious
abd22ba087 feat(ui): fix a few things that didn't unsubscribe correctly, add helper to manage subscriptions 2024-08-30 22:18:49 +10:00
psychedelicious
ab25546e97 feat(ui): merge bbox outline into transformer 2024-08-30 22:18:49 +10:00
psychedelicious
925f0fca2a fix(ui): update parent's pos not transformers 2024-08-30 22:18:49 +10:00
psychedelicious
066366d885 feat(ui): merge interaction rect into transformer class 2024-08-30 22:18:49 +10:00
psychedelicious
61d52e96b7 feat(ui): prepare staging area 2024-08-30 22:18:49 +10:00
psychedelicious
051e88ca90 feat(ui): typing for logging context 2024-08-30 22:18:49 +10:00
psychedelicious
e873b69850 feat(ui): remove inheritance of CanvasObject
JS is terrible
2024-08-30 22:18:49 +10:00
psychedelicious
661fd55556 feat(ui): split & document transformer logic, iterate on class structures 2024-08-30 22:18:49 +10:00
psychedelicious
402f5a4717 feat(ui): rotation snap to nearest 45deg when holding shift 2024-08-30 22:18:49 +10:00
psychedelicious
81bf52ef37 feat(ui): expose subscribe method for nanostores 2024-08-30 22:18:49 +10:00
psychedelicious
8ff92796df tidy(ui): remove layer scaling reducers 2024-08-30 22:18:49 +10:00
psychedelicious
68af60e12e fix(ui): pixel-perfect transforms 2024-08-30 22:18:49 +10:00
psychedelicious
cce6bf9428 fix(ui): layer visibility toggle 2024-08-30 22:18:49 +10:00
psychedelicious
078908fbea fix(nodes): fix canvas mask erode
it wasn't eroding enough and caused incorrect transparency in result images
2024-08-30 22:18:49 +10:00
psychedelicious
7275caaf5b fix(ui): do not reset layer on first render 2024-08-30 22:18:49 +10:00
psychedelicious
d9487c1df4 feat(ui): revised logging and naming setup, fix staging area 2024-08-30 22:18:49 +10:00
psychedelicious
3a9f955388 feat(ui): add repr methods to layer and object classes 2024-08-30 22:18:49 +10:00
psychedelicious
e46c7acd2e feat(ui): use nanoid(10) instead of uuidv4 for canvas
Shorter ids makes it much more readable
2024-08-30 22:18:49 +10:00
psychedelicious
b771664851 build(ui): add nanoid as explicit dep 2024-08-30 22:18:49 +10:00
psychedelicious
7c21819d20 fix(ui): move CanvasImage's konva image to correct object 2024-08-30 22:18:49 +10:00
psychedelicious
a57e618d47 fix(ui): prevent flash when applying transform 2024-08-30 22:18:49 +10:00
psychedelicious
c9849a79ea build(ui): add eslint rules for async stuff 2024-08-30 22:18:49 +10:00
psychedelicious
f1643fec08 feat(ui): trying to fix flicker after transform 2024-08-30 22:18:49 +10:00
psychedelicious
951e63ca87 feat(ui): transform cleanup 2024-08-30 22:18:49 +10:00
psychedelicious
8e539c8a8c feat(ui): fix transform when rotated 2024-08-30 22:18:49 +10:00
psychedelicious
1e689a4902 fix(ui): use pixel bbox when image is in layer 2024-08-30 22:18:49 +10:00
psychedelicious
7bbd25b5ec fix(ui): transforming when axes flipped 2024-08-30 22:18:49 +10:00
psychedelicious
b1c7236117 feat(ui): hallelujah (???) 2024-08-30 22:18:49 +10:00
psychedelicious
ae3e473024 feat(ui): add debug button 2024-08-30 22:18:49 +10:00
psychedelicious
fd616f247c fix(ui): transformer padding 2024-08-30 22:18:48 +10:00
psychedelicious
45dca2c821 feat(ui): wip transform mode 2 2024-08-30 22:18:48 +10:00
psychedelicious
40dc108c84 feat(ui): wip transform mode 2024-08-30 22:18:48 +10:00
psychedelicious
a421c25952 feat(ui): wip transform mode 2024-08-30 22:18:48 +10:00
psychedelicious
562d0afdbb fix(ui): dnd to canvas broke 2024-08-30 22:18:48 +10:00
psychedelicious
2ce4698eef fix(ui): conflicts after rebasing 2024-08-30 22:18:48 +10:00
psychedelicious
cb53108041 fix(ui): imageDropped listener 2024-08-30 22:18:48 +10:00
psychedelicious
5fa65e5cc6 wip 2024-08-30 22:18:48 +10:00
psychedelicious
e8b0b6cef5 fix(ui): transform tool seems to be working 2024-08-30 22:18:48 +10:00
psychedelicious
eca2712828 fix(ui): move tool fixes, add transform tool 2024-08-30 22:18:48 +10:00
psychedelicious
2804c0aede feat(ui): move tool now only moves 2024-08-30 22:18:48 +10:00
psychedelicious
0429f0480d feat(ui): layer bbox calc in worker 2024-08-30 22:18:48 +10:00
psychedelicious
024759a0fc feat(ui): tweaked entity & group selection styles 2024-08-30 22:18:48 +10:00
psychedelicious
9a94aef2b0 feat(ui): canvas entity list headers 2024-08-30 22:18:48 +10:00
psychedelicious
e329cb45cd tidy(ui): CanvasRegion 2024-08-30 22:18:48 +10:00
psychedelicious
0dc38bd684 tidy(ui): CanvasRect 2024-08-30 22:18:48 +10:00
psychedelicious
98ebca5f8c tidy(ui): CanvasLayer 2024-08-30 22:18:48 +10:00
psychedelicious
05cb3e03cf tidy(ui): CanvasInpaintMask 2024-08-30 22:18:48 +10:00
psychedelicious
181132c149 tidy(ui): CanvasInitialImage 2024-08-30 22:18:48 +10:00
psychedelicious
a69aa00155 tidy(ui): CanvasImage 2024-08-30 22:18:48 +10:00
psychedelicious
47d415e31c tidy(ui): CanvasEraserLine 2024-08-30 22:18:48 +10:00
psychedelicious
667a156817 tidy(ui): CanvasControlAdapter 2024-08-30 22:18:48 +10:00
psychedelicious
00f39b977e tidy(ui): CanvasBrushLine 2024-08-30 22:18:48 +10:00
psychedelicious
e5776e2bd6 tidy(ui): CanvasBbox 2024-08-30 22:18:48 +10:00
psychedelicious
2b21f54897 tidy(ui): CanvasBackground 2024-08-30 22:18:48 +10:00
psychedelicious
678d12fcd5 tidy(ui): update canvas classes, organise location of konva nodes 2024-08-30 22:18:48 +10:00
psychedelicious
03f06f611e feat(ui): add names to all konva objects
Makes troubleshooting much simpler
2024-08-30 22:18:48 +10:00
psychedelicious
6571e0f814 fix(ui): do not await creating new canvas image
If you await this, it causes a race condition where multiple images are created.
2024-08-30 22:18:48 +10:00
psychedelicious
44f91026e1 feat(ui): use position and dimensions instead of separate x,y,width,height attrs 2024-08-30 22:18:48 +10:00
psychedelicious
56237328f1 fix(ui): remove weird rtkq hook wrapper
I do not understand why I did that initially but it doesn't work with TS.
2024-08-30 22:18:48 +10:00
psychedelicious
ff68901e89 feat(ui): rename types size and position to dimensions and coordinate 2024-08-30 22:18:48 +10:00
psychedelicious
e0e7adb2b2 tidy(ui): hide layer settings by default 2024-08-30 22:18:48 +10:00
psychedelicious
0923a5b128 fix(ui): layer rendering when starting as disabled 2024-08-30 22:18:48 +10:00
psychedelicious
75f8a84c79 feat(invocation): reduce canvas v2 mask & crop mask dilation 2024-08-30 22:18:48 +10:00
psychedelicious
af815cf7eb feat(ui): de-jank staging area and progress images 2024-08-30 22:18:48 +10:00
psychedelicious
ef4d6c26f6 feat(ui): update staging handling to work w/ cropped mask 2024-08-30 22:18:48 +10:00
psychedelicious
5087b306c0 chore(ui): typegen 2024-08-30 22:18:48 +10:00
psychedelicious
a5708eaefe feat(app): update CanvasV2MaskAndCropInvocation 2024-08-30 22:18:48 +10:00
psychedelicious
389bfc9e31 feat(ui): use new canvas output node 2024-08-30 22:18:48 +10:00
psychedelicious
fd269e91e0 chore(ui): typegen 2024-08-30 22:18:48 +10:00
psychedelicious
80136b0dfc feat(app): add CanvasV2MaskAndCropInvocation & CanvasV2MaskAndCropOutput
This handles some masking and cropping that the canvas needs.
2024-08-30 22:18:48 +10:00
psychedelicious
9595eff1f9 fix(ui): restore nodes output tracking 2024-08-30 22:18:48 +10:00
psychedelicious
c3c95754f7 feat(ui): rip out document size
barely knew ye
2024-08-30 22:18:48 +10:00
psychedelicious
22ab63fe8d feat(ui): convert initial image to layer when starting canvas session 2024-08-30 22:18:48 +10:00
psychedelicious
5fefcab475 fix(ui): fix layer transparency calculation 2024-08-30 22:18:48 +10:00
psychedelicious
771a05b894 fix(ui): reset initial image when resetting canvas 2024-08-30 22:18:48 +10:00
psychedelicious
e2d8aaa923 fix(ui): reset node executions states when loading workflow 2024-08-30 22:18:48 +10:00
psychedelicious
0951aecb13 fix(ui): entity display list 2024-08-30 22:18:48 +10:00
psychedelicious
b1fe6f9853 feat(ui): img2img working 2024-08-30 22:18:48 +10:00
psychedelicious
551dd393aa feat(ui): rough out img2img on canvas 2024-08-30 22:18:48 +10:00
psychedelicious
78b4562184 UNDO ME WIP 2024-08-30 22:18:48 +10:00
psychedelicious
c49b90e621 feat(ui): log invocation source id on socket event 2024-08-30 22:18:48 +10:00
psychedelicious
89e6233fbf feat(ui): restore document size overlay renderer 2024-08-30 22:18:48 +10:00
psychedelicious
3f9496c237 feat(ui): make documnet size a rect 2024-08-30 22:18:48 +10:00
psychedelicious
36e94af598 refactor(ui): remove modular imagesize components
This is no longer necessary with canvas v2 and added a ton of extraneous redux actions when changing the image size. Also renamed to document size
2024-08-30 22:18:48 +10:00
psychedelicious
a181a684f5 feat(ui): initialState is for generation mode 2024-08-30 22:18:48 +10:00
psychedelicious
bb712b3b3f feat(ui): split out canvas entity list component 2024-08-30 22:18:48 +10:00
psychedelicious
e795de5647 feat(ui): hide bbox button when no canvas session active 2024-08-30 22:18:48 +10:00
psychedelicious
bdc428cdd8 tidy(ui): remove unused naming objects/utils
The canvas manager means we don't need to worry about konva node names as we never directly select konva nodes.
2024-08-30 22:18:48 +10:00
psychedelicious
e4376e21dd feat(ui): split up tool chooser buttons
Prep for distinct toolbars for generation vs canvas modes
2024-08-30 22:18:48 +10:00
psychedelicious
77acc7baed feat(ui): add useAssertSingleton util hook
This simple hook asserts that it is only ever called once. Particularly useful for things like hotkeys hooks.
2024-08-30 22:18:48 +10:00
psychedelicious
9db1556c4d feat(ui): "stagingArea" -> "session" 2024-08-30 22:18:48 +10:00
psychedelicious
65de8b329b feat(ui): add reset button to canvas 2024-08-30 22:18:48 +10:00
psychedelicious
08dae5b047 feat(ui): add snapToRect util 2024-08-30 22:18:48 +10:00
psychedelicious
8d2f056407 fix(ui): fiddle with control adapter filters
some jank still
2024-08-30 22:18:48 +10:00
psychedelicious
e66ef2e25e feat(ui): temp disable doc size overlay 2024-08-30 22:18:48 +10:00
psychedelicious
4d3ee7e082 feat(ui): no animation on layer selection
Felt sluggish
2024-08-30 22:18:48 +10:00
psychedelicious
fe48fda2f3 feat(ui): use canvas as source for control images (wip) 2024-08-30 22:18:48 +10:00
psychedelicious
0f66753aa1 fix(ui): control adapter translate & scale 2024-08-30 22:18:48 +10:00
psychedelicious
a18878474b tidy(ui): removed unused state related to non-buffered drawing 2024-08-30 22:18:48 +10:00
psychedelicious
0aa4568fd4 feat(ui): control adapter image rendering 2024-08-30 22:18:48 +10:00
psychedelicious
1de7e5760a fix(ui): do not floor bbox calc, it cuts off the last pixels 2024-08-30 22:18:48 +10:00
psychedelicious
135d6f2763 feat(ui): fix issue where creating line needs 2 points 2024-08-30 22:18:48 +10:00
psychedelicious
061767ede3 fix(ui): edge cases when holding shift and drawing lines 2024-08-30 22:18:48 +10:00
psychedelicious
7204844bcb fix(ui): set buffered rect color to full alpha 2024-08-30 22:18:48 +10:00
psychedelicious
f2279ecadd fix(ui): handle mouseup correctly 2024-08-30 22:18:48 +10:00
psychedelicious
75694869d2 feat(ui): buffered rect drawing 2024-08-30 22:18:48 +10:00
psychedelicious
d029680ac1 fix(ui): buffered drawing edge cases 2024-08-30 22:18:48 +10:00
psychedelicious
41c195d936 perf(ui): do not use stage.find 2024-08-30 22:18:48 +10:00
psychedelicious
03ea005e9c perf(ui): object groups do not listen 2024-08-30 22:18:48 +10:00
psychedelicious
6d936a7c44 perf(ui): buffered drawing (wip) 2024-08-30 22:18:48 +10:00
psychedelicious
fba17b93a6 tidy(ui): organise files 2024-08-30 22:18:48 +10:00
psychedelicious
73a7a27ea1 tidy(ui): organise files 2024-08-30 22:18:48 +10:00
psychedelicious
79287c2d16 tidy(ui): organise files 2024-08-30 22:18:48 +10:00
psychedelicious
662c5f4b77 fix(ui): background rendering 2024-08-30 22:18:48 +10:00
psychedelicious
7728ca6843 pkg(ui): remove unused deps react-konva & use-image 2024-08-30 22:18:48 +10:00
psychedelicious
9607372f89 feat(ui): organize konva state and files 2024-08-30 22:18:48 +10:00
psychedelicious
d27f948b78 fix(ui): merge conflicts in image deletion listener 2024-08-30 22:18:48 +10:00
psychedelicious
b7aab81717 fix(ui): region rendering 2024-08-30 22:18:48 +10:00
psychedelicious
2998287f61 fix(ui): inpaint mask rendering 2024-08-30 22:18:48 +10:00
psychedelicious
55d7f0ff5b fix(ui): staging area rendering 2024-08-30 22:18:48 +10:00
psychedelicious
4564f36d4a fix(ui): stale selected entity 2024-08-30 22:18:48 +10:00
psychedelicious
319de5c4e9 fix(ui): staging area image offset 2024-08-30 22:18:48 +10:00
psychedelicious
eee499faa3 feat(ui): tweak layer ui component 2024-08-30 22:18:48 +10:00
psychedelicious
63c5e42f2a fix(ui): resetting layer resets position 2024-08-30 22:18:48 +10:00
psychedelicious
bd16dc4479 feat(ui): updated layer list component styling 2024-08-30 22:18:48 +10:00
psychedelicious
49371ddec9 feat(ui): transformable layers 2024-08-30 22:18:48 +10:00
psychedelicious
6a10d31b19 feat(ui): move tool icon is pointer like in other apps 2024-08-30 22:18:48 +10:00
psychedelicious
c951e733d3 feat(ui): do not floor cursor position 2024-08-30 22:18:48 +10:00
psychedelicious
7ed24cf847 feat(ui): disable gallery hotkeys while staging 2024-08-30 22:18:48 +10:00
psychedelicious
821b7a0435 feat(ui): revised canvas progress & staging image handling 2024-08-30 22:18:48 +10:00
psychedelicious
1b0344c412 feat(ui): show queue item origin in queue list 2024-08-30 22:18:48 +10:00
psychedelicious
03ca3c4b3d chore(ui): typegen 2024-08-30 22:18:48 +10:00
psychedelicious
b939192b16 feat(app): add origin to session queue
The origin is an optional field indicating the queue item's origin. For example, "canvas" when the queue item originated from the canvas or "workflows" when the queue item originated from the workflows tab. If omitted, we assume the queue item originated from the API directly.

- Add migration to add the nullable column to the `session_queue` table.
- Update relevant event payloads with the new field.
- Add `cancel_by_origin` method to `session_queue` service and corresponding route. This is required for the canvas to bail out early when staging images.
- Add `origin` to both `SessionQueueItem` and `Batch` - it needs to be provided initially via the batch and then passed onto the queue item.
-
2024-08-30 22:18:48 +10:00
psychedelicious
7ccf559a06 fix(ui): denoise start on outpainting 2024-08-30 22:18:48 +10:00
psychedelicious
9eb091f873 feat(ui): add redux events for queue cleared & batch enqueued socket events 2024-08-30 22:18:48 +10:00
psychedelicious
3bd5521641 feat(ui): canvas staging area works 2024-08-30 22:18:48 +10:00
psychedelicious
ced748e419 feat(ui): switch to view tool when staging 2024-08-30 22:18:48 +10:00
psychedelicious
fbd137da9f tidy(ui): disable preview images on every enqueue 2024-08-30 22:18:48 +10:00
psychedelicious
03baebced6 feat(ui): rough out save staging image 2024-08-30 22:18:48 +10:00
psychedelicious
cb19c1c370 feat(ui): staging area image visibility toggle 2024-08-30 22:18:47 +10:00
psychedelicious
788bad61d0 fix(ui): batch building after removing canvas files 2024-08-30 22:18:47 +10:00
psychedelicious
8f5f9bd44e feat(ui): make Graph class's getMetadataNode public 2024-08-30 22:18:47 +10:00
psychedelicious
2873e3e084 tidy(ui): remove old canvas graphs 2024-08-30 22:18:47 +10:00
psychedelicious
b004f17ae3 fix(ui): do not select already-selected entity 2024-08-30 22:18:47 +10:00
psychedelicious
bea1e8c99b tidy(ui): naming things 2024-08-30 22:18:47 +10:00
psychedelicious
111493223f tidy(ui): file organisation 2024-08-30 22:18:47 +10:00
psychedelicious
0a5ac2baec fix(ui): reset cursor pos when fitting document 2024-08-30 22:18:47 +10:00
psychedelicious
eec3c3b884 feat(ui): staging area works more better 2024-08-30 22:18:47 +10:00
psychedelicious
07b72c3d70 feat(ui): staging area barely works 2024-08-30 22:18:47 +10:00
psychedelicious
766e8c4eb0 feat(ui): consolidate konva API 2024-08-30 22:18:47 +10:00
psychedelicious
57c257d10d feat(ui): consolidate konva API 2024-08-30 22:18:47 +10:00
psychedelicious
d497da0e61 feat(ui): staging area (rendering wip) 2024-08-30 22:18:47 +10:00
psychedelicious
62310e7929 tidy(ui): type "Dimensions" -> "Size" 2024-08-30 22:18:47 +10:00
psychedelicious
d79aa173a6 feat(ui): add updateNode to Graph 2024-08-30 22:18:47 +10:00
psychedelicious
fbfdd3e003 feat(ui): sdxl graphs 2024-08-30 22:18:47 +10:00
psychedelicious
a62b4a26ef feat(ui): sd1 outpaint graph 2024-08-30 22:18:47 +10:00
psychedelicious
817d4168c6 tests(ui): add missing tests for Graph class 2024-08-30 22:18:47 +10:00
psychedelicious
7e0a6d1538 feat(ui): add Graph.getid() util 2024-08-30 22:18:47 +10:00
psychedelicious
ebc498ad19 feat(ui): outpaint graph, organize builder a bit 2024-08-30 22:18:47 +10:00
psychedelicious
b97b8c6ce6 feat(ui): inpaint sd1 graph 2024-08-30 22:18:47 +10:00
psychedelicious
b8abff65a1 feat(ui): temp disable image caching while testing 2024-08-30 22:18:47 +10:00
psychedelicious
a953dc1dbd feat(ui): txt2img & img2img graphs 2024-08-30 22:18:47 +10:00
psychedelicious
a7c9848e99 feat(ui): minor change to canvas bbox state type 2024-08-30 22:18:47 +10:00
psychedelicious
73a1449eaf feat(ui): simplified konva node to blob/imagedata utils 2024-08-30 22:18:47 +10:00
psychedelicious
59f57ff542 feat(ui): node manager getter/setter 2024-08-30 22:18:47 +10:00
psychedelicious
e9204b87e3 feat(ui): generation mode calculation, fudged graphs 2024-08-30 22:18:47 +10:00
psychedelicious
7dd11bd60a feat(ui): add utils for getting images from canvas 2024-08-30 22:18:47 +10:00
psychedelicious
275fc2ccf9 feat(ui): even more simplified API - lean on the konva node manager to abstract imperative state API & rendering 2024-08-30 22:18:47 +10:00
psychedelicious
a2ef8d9d47 feat(ui): revised docstrings for renderers & simplified api 2024-08-30 22:18:47 +10:00
psychedelicious
196779ff19 feat(ui): inpaint mask UI components 2024-08-30 22:18:47 +10:00
psychedelicious
aee3147365 feat(ui): inpaint mask rendering (wip) 2024-08-30 22:18:47 +10:00
psychedelicious
eaca940956 fix(ui): models loaded handler 2024-08-30 22:18:47 +10:00
psychedelicious
06006733e2 feat(ui): internal state for inpaint mask 2024-08-30 22:18:47 +10:00
psychedelicious
14d0bfbef6 refactor(ui): divvy up canvas state a bit 2024-08-30 22:18:47 +10:00
psychedelicious
0c9cf73702 feat(ui): get region and base layer canvas to blob logic working 2024-08-30 22:18:47 +10:00
psychedelicious
3b864921ac refactor(ui): node manager handles more tedious annoying stuff 2024-08-30 22:18:47 +10:00
psychedelicious
f41539532f feat(ui): use node manager for addRegions 2024-08-30 22:18:47 +10:00
psychedelicious
657009c254 feat(ui): persist bbox 2024-08-30 22:18:47 +10:00
psychedelicious
c47e02c309 fix(ui): fix generation graphs 2024-08-30 22:18:47 +10:00
psychedelicious
ce8a7bc178 feat(ui): add toggle for clipToBbox 2024-08-30 22:18:47 +10:00
psychedelicious
488ca87787 feat(ui): rename konva node manager 2024-08-30 22:18:47 +10:00
psychedelicious
d965df8ca9 refactor(ui): create classes to abstract mgmt of konva nodes 2024-08-30 22:18:47 +10:00
psychedelicious
995c26751e tidy(ui): organise renderers 2024-08-30 22:18:47 +10:00
psychedelicious
dd09723a2a refactor(ui): create entity to konva node map abstraction (wip)
Instead of chaining konva `find` and `findOne` methods, all konva nodes are added to a mapping object. Finding and manipulating them is much simpler.

Done for regions and layers, wip for control adapters.
2024-08-30 22:18:47 +10:00
psychedelicious
5ff5af3ba2 perf(ui): fix lag w/ region rendering
Needed to memoize these selectors
2024-08-30 22:18:47 +10:00
psychedelicious
4cb85404c0 feat(ui): move canvas fill color picker to right 2024-08-30 22:18:47 +10:00
psychedelicious
50bc2f100d refactor(ui): remove unused ellipse & polygon objects 2024-08-30 22:18:47 +10:00
psychedelicious
f65ce6a019 fix(ui): incorrect rect/brush/eraser positions 2024-08-30 22:18:47 +10:00
psychedelicious
c28b635f2d refactor(ui): enable global debugging flag 2024-08-30 22:18:47 +10:00
psychedelicious
e55896240d refactor(ui): disable the preview renderer for now 2024-08-30 22:18:47 +10:00
psychedelicious
2b478ee7e1 tweak(ui): canvas editor layout 2024-08-30 22:18:47 +10:00
psychedelicious
69912a35ea perf(ui): memoize layeractionsmenu valid actions 2024-08-30 22:18:47 +10:00
psychedelicious
9f1bd98c7e refactor(ui): decouple konva renderer from react
Subscribe to redux store directly, skipping all the react overhead.

With react in dev mode, a typical frame while using the brush tool on almost-empty canvas is reduced from ~7.5ms to ~3.5ms. All things considered, this still feels slow, but it's a massive improvement.
2024-08-30 22:18:47 +10:00
psychedelicious
b531d6b7f0 feat(ui): clip lines to bbox 2024-08-30 22:18:47 +10:00
psychedelicious
8aa963fb81 fix(ui): document fit positioning 2024-08-30 22:18:47 +10:00
psychedelicious
b76e0ab4e4 feat(ui): document bounds overlay 2024-08-30 22:18:47 +10:00
psychedelicious
aea03b4e92 tidy(ui): background layer 2024-08-30 22:18:47 +10:00
psychedelicious
b39e95966c refactor(ui): use "entity" instead of "data" for canvas 2024-08-30 22:18:47 +10:00
psychedelicious
d53e5e0158 feat(ui): brush size border radius = 1 2024-08-30 22:18:47 +10:00
psychedelicious
0368dd651b fix(ui): canvas HUD doesn't interrupt tool 2024-08-30 22:18:47 +10:00
psychedelicious
84a4a1024e refactor(ui): split up canvas entity renderers, temp disable preview 2024-08-30 22:18:47 +10:00
psychedelicious
af4f258489 fix(ui): delete all layers button 2024-08-30 22:18:47 +10:00
psychedelicious
ddfc8785b4 fix(ui): ignore keyboard shortcuts in input/textarea elements 2024-08-30 22:18:47 +10:00
psychedelicious
d8515b6efc fix(ui): canvas entity ids getting clobbered 2024-08-30 22:18:47 +10:00
psychedelicious
6a07f007a4 fix(ui): move lora followup fixes 2024-08-30 22:18:47 +10:00
psychedelicious
7a5a0c8075 chore(ui): lint 2024-08-30 22:18:47 +10:00
psychedelicious
5ed2e9b0fc refactor(ui): move loras to canvas slice 2024-08-30 22:18:47 +10:00
psychedelicious
aeb0a45eb6 fix(ui): layer is selected when added 2024-08-30 22:18:47 +10:00
psychedelicious
21e814d766 feat(ui): r to center & fit stage on document 2024-08-30 22:18:47 +10:00
psychedelicious
cafc1839e2 feat(ui): better HUD 2024-08-30 22:18:47 +10:00
psychedelicious
e937aa831f fix(ui): always use current brush width when making straight lines 2024-08-30 22:18:47 +10:00
psychedelicious
890e6a95ed feat(ui): hold shift w/ brush to draw straight line 2024-08-30 22:18:47 +10:00
psychedelicious
a5b7274359 fix(ui): update bg on canvas resize 2024-08-30 22:18:47 +10:00
psychedelicious
172acf2cf5 refactor(ui): better hud 2024-08-30 22:18:47 +10:00
psychedelicious
b49fdf6407 refactor(ui): scaled tool preview border 2024-08-30 22:18:47 +10:00
psychedelicious
5184d05bc2 refactor(ui): port remaining canvasV1 rendering logic to V2, remove old code 2024-08-30 22:18:47 +10:00
psychedelicious
7ef4553fc9 refactor(ui): fix more types 2024-08-30 22:18:47 +10:00
psychedelicious
d6bd1e4a49 refactor(ui): metadata recall (wip)
just enough let the app run
2024-08-30 22:18:47 +10:00
psychedelicious
29413f20a7 refactor(ui): undo/redo button temp fix 2024-08-30 22:18:47 +10:00
psychedelicious
04a44c8ea7 refactor(ui): fix renderer stuff 2024-08-30 22:18:47 +10:00
psychedelicious
426f1b6f9a refactor(ui): fix misc types 2024-08-30 22:18:47 +10:00
psychedelicious
9c7f5ed321 refactor(ui): fix gallery stuff 2024-08-30 22:18:47 +10:00
psychedelicious
4c37c7f280 refactor(ui): fix delete image stuff 2024-08-30 22:18:47 +10:00
psychedelicious
a2d13cacbf refactor(ui): fix useIsReadyToEnqueue for new adapterType field 2024-08-30 22:18:47 +10:00
psychedelicious
aa127b83a3 refactor(ui): update generation tab graphs 2024-08-30 22:18:47 +10:00
psychedelicious
e55192ae2a refactor(ui): add adapterType to ControlAdapterData 2024-08-30 22:18:47 +10:00
psychedelicious
5159fcbc33 refactor(ui): update components & logic to use new unified slice (again) 2024-08-30 22:18:47 +10:00
psychedelicious
02ad7a0f93 refactor(ui): update components & logic to use new unified slice 2024-08-30 22:18:47 +10:00
psychedelicious
bfa496e37f refactor(ui): merge compositing, params into canvasV2 slice 2024-08-30 22:18:47 +10:00
psychedelicious
fdf347af26 refactor(ui): add scaled bbox state 2024-08-30 22:18:47 +10:00
psychedelicious
0833dbb19d refactor(ui): update dnd/image upload 2024-08-30 22:18:47 +10:00
psychedelicious
1b6bf58e58 refactor(ui): update size/prompts state 2024-08-30 22:18:47 +10:00
psychedelicious
5ead7bc7b4 refactor(ui): rip out old control adapter implementation 2024-08-30 22:18:47 +10:00
psychedelicious
f326d17856 refactor(ui): canvas v2 (wip)
fix entity count select
2024-08-30 22:18:47 +10:00
psychedelicious
908aa9beea refactor(ui): canvas v2 (wip)
delete unused file
2024-08-30 22:18:47 +10:00
psychedelicious
4071e96245 refactor(ui): canvas v2 (wip)
merge all canvas state reducers into one big slice (but with the logic split across files so it's not hell)
2024-08-30 22:18:47 +10:00
psychedelicious
b4daf29bd8 refactor(ui): canvas v2 (wip)
Fix a few more components
2024-08-30 22:18:47 +10:00
psychedelicious
bf185339c2 refactor(ui): canvas v2 (wip)
missed a spot
2024-08-30 22:18:47 +10:00
psychedelicious
df3abc75c2 refactor(ui): canvas v2 (wip)
Redo all UI components for different canvas entity types
2024-08-30 22:18:47 +10:00
psychedelicious
28fc9a387c refactor(ui): canvas v2 (wip) 2024-08-30 22:18:47 +10:00
psychedelicious
8533f207dc refactor(ui): canvas v2 (wip) 2024-08-30 22:18:47 +10:00
psychedelicious
d135c48319 refactor(ui): canvas v2 (wip) 2024-08-30 22:18:47 +10:00
psychedelicious
ca9090d070 refactor(ui): canvas v2 (wip) 2024-08-30 22:18:47 +10:00
psychedelicious
93b185dc3b feat(ui): bbox tool 2024-08-30 22:18:47 +10:00
psychedelicious
98e5efa895 fix(ui): rect tool preview 2024-08-30 22:18:47 +10:00
psychedelicious
c6774b829d fix(ui): multiple stages 2024-08-30 22:18:47 +10:00
psychedelicious
22925f92bd feat(ui): decouple konva logic from nanostores 2024-08-30 22:18:47 +10:00
psychedelicious
302efcf6e8 feat(ui): store all stage attrs in nanostores 2024-08-30 22:18:47 +10:00
psychedelicious
76f9f90f0a feat(ui): round stage scale 2024-08-30 22:18:47 +10:00
psychedelicious
5ba338e471 chore(ui): bump konva 2024-08-30 22:18:47 +10:00
psychedelicious
01f101c6f2 feat(ui): generation bbox transformation working
whew
2024-08-30 22:18:47 +10:00
psychedelicious
5606aec78d feat(ui): wip generation bbox 2024-08-30 22:18:47 +10:00
psychedelicious
db90e1fe8b feat(ui): wip generation bbox 2024-08-30 22:18:47 +10:00
psychedelicious
ae96c479f2 feat(ui): CL zoom and pan, some rendering optimizations 2024-08-30 22:18:47 +10:00
psychedelicious
344ed2c83e Revert "feat(ui): add x,y,scaleX,scaleY,rotation to objects"
This reverts commit 53318b396c967c72326a7e4dea09667b2ab20bdd.
2024-08-30 22:18:47 +10:00
psychedelicious
1985944659 feat(ui): layers manage their own bbox 2024-08-30 22:18:47 +10:00
psychedelicious
915357a6c1 docs(ui): konva image object docstrings 2024-08-30 22:18:47 +10:00
psychedelicious
63c34e78d7 feat(ui): add x,y,scaleX,scaleY,rotation to objects 2024-08-30 22:18:47 +10:00
psychedelicious
366c460c1f fix(ui): show color picker when using rect tool 2024-08-30 22:18:47 +10:00
psychedelicious
40cab08133 feat(ui): image loading fallback for raster layers 2024-08-30 22:18:47 +10:00
psychedelicious
51de25122a feat(ui): bbox calc for raster layers 2024-08-30 22:18:47 +10:00
psychedelicious
90313091db feat(ui): do not fill brush preview when drawing 2024-08-30 22:18:47 +10:00
psychedelicious
9982219d18 fix(ui): brush spacing handling 2024-08-30 22:18:47 +10:00
psychedelicious
b3fe03b8f9 fix(ui): jank when starting a shape when not already focused on stage 2024-08-30 22:18:47 +10:00
psychedelicious
6edd15d68a feat(ui): wip raster layers
I meant to split this up into smaller commits and undo some of it, but I committed afterwards and it's tedious to undo.
2024-08-30 22:18:47 +10:00
psychedelicious
0e2b328c88 feat(ui): support image objects on raster layers
Just the UI and internal state, not rendering yet.
2024-08-30 22:18:47 +10:00
psychedelicious
25d7f9c316 tidy(ui): clean up event handlers
Separate logic for each tool in preparation for ellipse and polygon tools.
2024-08-30 22:18:47 +10:00
psychedelicious
3870ebdf29 feat(ui): raster layer reset, object group util 2024-08-30 22:18:47 +10:00
psychedelicious
7595d05191 feat(ui): rect shape preview now has fill 2024-08-30 22:18:47 +10:00
psychedelicious
21af727d79 feat(ui): cancel shape drawing on esc 2024-08-30 22:18:47 +10:00
psychedelicious
5691829de6 feat(ui): temp disable history on CL 2024-08-30 22:18:47 +10:00
psychedelicious
20e6a57cf1 feat(ui): raster layer logic
- Deduplicate shared logic
- Split up giant renderers file into separate cohesive files
- Tons of cleanup
- Progress on raster layer functionality
2024-08-30 22:18:47 +10:00
psychedelicious
d0c40a8b5b feat(ui): add raster layer rendering and interaction (WIP) 2024-08-30 22:18:46 +10:00
psychedelicious
f663215f25 feat(ui): scaffold out raster layers
Raster layers may have images, lines and shapes. These will replace initial image layers and provide sketching functionality like we have on canvas.
2024-08-30 22:18:46 +10:00
psychedelicious
7c5dea6d12 refactor(ui): revise types for line and rect objects
- Create separate object types for brush and eraser lines, instead of a single type that has a `tool` field.
- Create new object type for rect shapes.
- Add logic to schemas to migrate old object types to new.
- Update renderers & reducers.
2024-08-30 22:18:46 +10:00
902 changed files with 25225 additions and 34890 deletions

View File

@@ -196,22 +196,6 @@ tips to reduce the problem:
=== "12GB VRAM GPU"
This should be sufficient to generate larger images up to about 1280x1280.
## Checkpoint Models Load Slowly or Use Too Much RAM
The difference between diffusers models (a folder containing multiple
subfolders) and checkpoint models (a file ending with .safetensors or
.ckpt) is that InvokeAI is able to load diffusers models into memory
incrementally, while checkpoint models must be loaded all at
once. With very large models, or systems with limited RAM, you may
experience slowdowns and other memory-related issues when loading
checkpoint models.
To solve this, go to the Model Manager tab (the cube), select the
checkpoint model that's giving you trouble, and press the "Convert"
button in the upper right of your browser window. This will conver the
checkpoint into a diffusers model, after which loading should be
faster and less memory-intensive.
## Memory Leak (Linux)

View File

@@ -3,10 +3,8 @@
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
from tempfile import TemporaryDirectory
from typing import List, Optional, Type
@@ -19,7 +17,6 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config import get_config
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
@@ -34,7 +31,6 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
@@ -54,13 +50,6 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True)
class CacheType(str, Enum):
"""Cache type - one of vram or ram."""
RAM = "RAM"
VRAM = "VRAM"
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
"""Add a cover image URL to a model configuration."""
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
@@ -808,83 +797,3 @@ async def get_starter_models() -> list[StarterModel]:
model.dependencies = missing_deps
return starter_models
@model_manager_router.get(
"/model_cache",
operation_id="get_cache_size",
response_model=float,
summary="Get maximum size of model manager RAM or VRAM cache.",
)
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
"""Return the current RAM or VRAM cache size setting (in GB)."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
value = 0.0
if cache_type == CacheType.RAM:
value = cache.max_cache_size
elif cache_type == CacheType.VRAM:
value = cache.max_vram_cache_size
return value
@model_manager_router.put(
"/model_cache",
operation_id="set_cache_size",
response_model=float,
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
)
async def set_cache_size(
value: float = Query(description="The new value for the maximum cache size"),
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
) -> float:
"""Set the current RAM or VRAM cache size setting (in GB). ."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
app_config = get_config()
# Record initial state.
vram_old = app_config.vram
ram_old = app_config.ram
# Prepare target state.
vram_new = vram_old
ram_new = ram_old
if cache_type == CacheType.RAM:
ram_new = value
elif cache_type == CacheType.VRAM:
vram_new = value
else:
raise ValueError(f"Unexpected {cache_type=}.")
config_path = app_config.config_file_path
new_config_path = config_path.with_suffix(".yaml.new")
try:
# Try to apply the target state.
cache.max_vram_cache_size = vram_new
cache.max_cache_size = ram_new
app_config.ram = ram_new
app_config.vram = vram_new
if persist:
app_config.write_file(new_config_path)
shutil.move(new_config_path, config_path)
except Exception as e:
# If there was a failure, restore the initial state.
cache.max_cache_size = ram_old
cache.max_vram_cache_size = vram_old
app_config.ram = ram_old
app_config.vram = vram_old
raise RuntimeError("Failed to update cache size") from e
return value
@model_manager_router.get(
"/stats",
operation_id="get_stats",
response_model=Optional[CacheStats],
summary="Get model manager RAM cache performance statistics.",
)
async def get_stats() -> Optional[CacheStats]:
"""Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats

View File

@@ -11,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
ClearResult,
EnqueueBatchResult,
PruneResult,
@@ -105,6 +106,19 @@ async def cancel_by_batch_ids(
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
"/{queue_id}/cancel_by_origin",
operation_id="cancel_by_origin",
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_origin(
queue_id: str = Path(description="The queue id to perform this operation on"),
origin: str = Query(description="The origin to cancel all queue items for"),
) -> CancelByOriginResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
@session_queue_router.put(
"/{queue_id}/clear",
operation_id="clear",

View File

@@ -20,7 +20,6 @@ from typing import (
Type,
TypeVar,
Union,
cast,
)
import semver
@@ -80,7 +79,7 @@ class UIConfigBase(BaseModel):
version: str = Field(
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict(
@@ -230,18 +229,16 @@ class BaseInvocation(ABC, BaseModel):
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
if uiconfig.title is not None:
schema["title"] = uiconfig.title
if uiconfig.tags is not None:
schema["tags"] = uiconfig.tags
if uiconfig.category is not None:
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if title := model_class.UIConfig.title:
schema["title"] = title
if tags := model_class.UIConfig.tags:
schema["tags"] = tags
if category := model_class.UIConfig.category:
schema["category"] = category
if node_pack := model_class.UIConfig.node_pack:
schema["node_pack"] = node_pack
schema["classification"] = model_class.UIConfig.classification
schema["version"] = model_class.UIConfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
@@ -312,7 +309,7 @@ class BaseInvocation(ABC, BaseModel):
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
UIConfig: ClassVar[Type[UIConfigBase]]
UIConfig: ClassVar[UIConfigBase]
model_config = ConfigDict(
protected_namespaces=(),
@@ -441,30 +438,25 @@ def invocation(
validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras
uiconfig_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
cls.UIConfig.classification = classification
# Grab the node pack's name from the module name, if it's a custom node
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
if is_custom_node:
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
else:
cls.UIConfig.node_pack = None
uiconfig: dict[str, Any] = {}
uiconfig["title"] = title
uiconfig["tags"] = tags
uiconfig["category"] = category
uiconfig["classification"] = classification
# The node pack is the module name - will be "invokeai" for built-in nodes
uiconfig["node_pack"] = cls.__module__.split(".")[0]
if version is not None:
try:
semver.Version.parse(version)
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
uiconfig["version"] = version
else:
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
cls.UIConfig.version = "1.0.0"
uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache

View File

@@ -19,8 +19,8 @@ from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,

View File

@@ -36,9 +36,9 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.denoise_mask,
description=FieldDescriptions.mask,
input=Input.Connection,
ui_order=8,
)

View File

@@ -181,7 +181,7 @@ class FieldDescriptions:
)
num_1 = "The first number"
num_2 = "The second number"
denoise_mask = "A mask of the region to apply the denoising process to."
mask = "The mask to use for the operation"
board = "The board to save the image to"
image = "The image to process"
tile_size = "Tile size"

View File

@@ -1,292 +0,0 @@
from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule,
generate_img_ids,
get_noise,
get_schedule,
pack,
unpack,
)
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.peft.peft_patcher import PeftPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_denoise",
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a FLUX transformer model."""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.denoise_mask,
input=Input.Connection,
)
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
# Prepare input noise.
noise = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=image_seq_len,
shift=not is_schnell,
)
# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)
# Prepare input latent image.
if init_latents is not None:
# If init_latents is provided, we are doing image-to-image.
if is_schnell:
context.logger.warning(
"Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
"to be poor. Consider using a FLUX dev model instead."
)
# Noise the orig_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
x = t_0 * noise + (1.0 - t_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
x = noise
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
# denoising steps.
if len(timesteps) <= 1:
return x
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, h, w = x.shape
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
noise = pack(noise)
x = pack(x)
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
assert image_seq_len == x.shape[1]
# Prepare inpaint extension.
inpaint_extension: InpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = InpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)
with (
transformer_info.model_on_device() as (cached_weights, transformer),
# Apply the LoRA after transformer has been moved to its target device for faster patching.
PeftPatcher.apply_peft_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix="",
cached_weights=cached_weights,
),
):
assert isinstance(transformer, Flux)
x = denoise(
model=transformer,
img=x,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
)
x = unpack(x.float(), self.height, self.width)
return x
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
- Loads the mask
- Resizes if necessary
- Casts to same device/dtype as latents
- Expands mask to the same shape as latents so that they line up after 'packing'
Args:
context (InvocationContext): The invocation context, for loading the inpaint mask.
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
device, and dtype for the inpaint mask.
Returns:
torch.Tensor | None: Inpaint mask.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
# `latents`.
return mask.expand_as(latents)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]:
def step_callback() -> None:
if context.util.is_canceled():
raise CanceledException
# TODO: Make this look like the image before re-enabling
# latent_image = unpack(img.float(), self.height, self.width)
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
# # Create a new tensor of the required shape [255, 255, 3]
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
# # Convert to a NumPy array and then to a PIL Image
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
# (width, height) = image.size
# width *= 8
# height *= 8
# dataURL = image_to_dataURL(image, image_format="JPEG")
# # TODO: move this whole function to invocation context to properly reference these variables
# context._services.events.emit_invocation_denoise_progress(
# context._data.queue_item,
# context._data.invocation,
# state,
# ProgressImage(dataURL=dataURL, width=width, height=height),
# )
return step_callback

View File

@@ -1,53 +0,0 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
"""FLUX LoRA Loader Output"""
transformer: TransformerField = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_lora_loader",
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
transformer = self.transformer.model_copy(deep=True)
transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return FluxLoRALoaderOutput(transformer=transformer)

View File

@@ -0,0 +1,169 @@
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_text_to_image",
title="FLUX Text to Image",
tags=["image", "flux"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = self._run_diffusion(context)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
transformer_info = context.models.load(self.transformer.transformer)
# Prepare input noise.
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
x, img_ids = prepare_latent_img_patches(x)
is_schnell = "schnell" in transformer_info.config.config_path
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=x.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
with transformer_info as transformer:
assert isinstance(transformer, Flux)
def step_callback() -> None:
if context.util.is_canceled():
raise CanceledException
# TODO: Make this look like the image before re-enabling
# latent_image = unpack(img.float(), self.height, self.width)
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
# # Create a new tensor of the required shape [255, 255, 3]
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
# # Convert to a NumPy array and then to a PIL Image
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
# (width, height) = image.size
# width *= 8
# height *= 8
# dataURL = image_to_dataURL(image, image_format="JPEG")
# # TODO: move this whole function to invocation context to properly reference these variables
# context._services.events.emit_invocation_denoise_progress(
# context._data.queue_item,
# context._data.invocation,
# state,
# ProgressImage(dataURL=dataURL, width=width, height=height),
# )
x = denoise(
model=transformer,
img=x,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
step_callback=step_callback,
guidance=self.guidance,
)
x = unpack(x.float(), self.height, self.width)
return x
def _run_vae_decoding(
self,
context: InvocationContext,
latents: torch.Tensor,
) -> Image.Image:
vae_info = context.models.load(self.vae.vae)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil

View File

@@ -1,60 +0,0 @@
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_vae_decode",
title="FLUX Latents to Image",
tags=["latents", "image", "vae", "l2i", "flux"],
category="latents",
version="1.0.0",
)
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@@ -1,67 +0,0 @@
import einops
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_vae_encode",
title="FLUX Image to Latents",
tags=["latents", "image", "vae", "i2l", "flux"],
category="latents",
version="1.0.0",
)
class FluxVaeEncodeInvocation(BaseInvocation):
"""Encodes an image into latents."""
image: ImageField = InputField(
description="The image to encode.",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
# TODO(ryand): Expose seed parameter at the invocation level.
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
# should be used for VAE encode sampling.
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

View File

@@ -6,13 +6,19 @@ import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
ColorField,
FieldDescriptions,
ImageField,
InputField,
OutputField,
WithBoard,
WithMetadata,
)
@@ -1007,3 +1013,62 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
@invocation_output("canvas_v2_mask_and_crop_output")
class CanvasV2MaskAndCropOutput(ImageOutput):
offset_x: int = OutputField(description="The x offset of the image, after cropping")
offset_y: int = OutputField(description="The y offset of the image, after cropping")
@invocation(
"canvas_v2_mask_and_crop",
title="Canvas V2 Mask and Crop",
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Handles Canvas V2 image output masking and cropping"""
source_image: ImageField | None = InputField(
default=None,
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
)
generated_image: ImageField = InputField(description="The image to apply the mask to")
mask: ImageField = InputField(description="The mask to apply")
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
mask_array = numpy.array(mask)
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
dilated_mask = Image.fromarray(dilated_mask_array)
if self.mask_blur > 0:
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
if self.source_image:
generated_image = context.images.get_pil(self.generated_image.image_name)
source_image = context.images.get_pil(self.source_image.image_name)
source_image.paste(generated_image, (0, 0), mask)
image_dto = context.images.save(image=source_image)
else:
generated_image = context.images.get_pil(self.generated_image.image_name)
generated_image.putalpha(mask)
image_dto = context.images.save(image=generated_image)
# bbox = image.getbbox()
# image = image.crop(bbox)
return CanvasV2MaskAndCropOutput(
image=ImageField(image_name=image_dto.image_name),
offset_x=0,
offset_y=0,
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
title="Tensor Mask to Image",
tags=["mask"],
category="mask",
version="1.1.0",
version="1.0.0",
)
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Convert a mask tensor to an image."""
@@ -135,11 +135,6 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.tensors.load(self.mask.tensor_name)
# Squeeze the channel dimension if it exists.
if mask.dim() == 3:
mask = mask.squeeze(0)
# Ensure that the mask is binary.
if mask.dtype != torch.bool:
mask = mask > 0.5

View File

@@ -69,7 +69,6 @@ class CLIPField(BaseModel):
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class T5EncoderField(BaseModel):
@@ -203,7 +202,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),

View File

@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,

View File

@@ -88,6 +88,8 @@ class QueueItemEventBase(QueueEventBase):
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
origin: str | None = Field(default=None, description="The origin of the queue item")
destination: str | None = Field(default=None, description="The destination of the queue item")
class InvocationEventBase(QueueItemEventBase):
@@ -95,8 +97,6 @@ class InvocationEventBase(QueueItemEventBase):
session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@@ -114,6 +114,8 @@ class InvocationStartedEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -147,6 +149,8 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -184,6 +188,8 @@ class InvocationCompleteEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -216,6 +222,8 @@ class InvocationErrorEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -253,6 +261,8 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
destination=queue_item.destination,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,
@@ -279,12 +289,14 @@ class BatchEnqueuedEvent(QueueEventBase):
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
origin=enqueue_result.batch.origin,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,

View File

@@ -103,7 +103,7 @@ class HFModelSource(StringLikeSource):
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f"::{self.subfolder.as_posix()}"
base += f":{self.subfolder}"
return base

View File

@@ -6,6 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@@ -95,6 +96,11 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with matching batch IDs"""
pass
@abstractmethod
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
"""Cancels all queue items with the given batch origin"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""

View File

@@ -77,6 +77,14 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
origin: str | None = Field(
default=None,
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
)
destination: str | None = Field(
default=None,
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
)
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field(
@@ -195,6 +203,14 @@ class SessionQueueItemWithoutGraph(BaseModel):
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item")
origin: str | None = Field(
default=None,
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
)
destination: str | None = Field(
default=None,
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
)
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
@@ -294,6 +310,8 @@ class SessionQueueStatus(BaseModel):
class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch")
origin: str | None = Field(..., description="The origin of the batch")
destination: str | None = Field(..., description="The destination of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
@@ -328,6 +346,12 @@ class CancelByBatchIDsResult(BaseModel):
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByOriginResult(BaseModel):
"""Result of canceling by list of batch ids"""
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""
@@ -433,6 +457,8 @@ class SessionQueueValueToInsert(NamedTuple):
field_values: Optional[str] # field_values json
priority: int # priority
workflow: Optional[str] # workflow json
origin: str | None
destination: str | None
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@@ -453,6 +479,8 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
batch.origin, # origin
batch.destination, # destination
)
)
return values_to_insert

View File

@@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@@ -127,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
VALUES (?, ?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
@@ -417,11 +418,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
raise
@@ -429,6 +426,46 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id == ?
AND origin == ?
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = (queue_id, origin)
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
params,
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
params,
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.origin == origin:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByOriginResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
current_queue_item = self.get_current(queue_id)
@@ -541,7 +578,9 @@ class SqliteSessionQueue(SessionQueueBase):
started_at,
session_id,
batch_id,
queue_id
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
@@ -621,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
@@ -633,6 +672,8 @@ class SqliteSessionQueue(SessionQueueBase):
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
except Exception:
self.__conn.rollback()
raise
@@ -641,6 +682,8 @@ class SqliteSessionQueue(SessionQueueBase):
return BatchStatus(
batch_id=batch_id,
origin=origin,
destination=destination,
queue_id=queue_id,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),

View File

@@ -17,6 +17,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -51,6 +52,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14())
migrator.register_migration(build_migration_15())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,34 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration15Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_origin_col(cursor)
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `origin` column to the session queue table.
- Adds `destination` column to the session queue table.
"""
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
cursor.execute("ALTER TABLE session_queue ADD COLUMN destination TEXT;")
def build_migration_15() -> Migration:
"""
Build the migration from database version 14 to 15.
This migration does the following:
- Adds `origin` column to the session queue table.
- Adds `destination` column to the session queue table.
"""
migration_15 = Migration(
from_version=14,
to_version=15,
callback=Migration15Callback(),
)
return migration_15

View File

@@ -1,407 +0,0 @@
{
"name": "FLUX Image to Image",
"author": "InvokeAI",
"description": "A simple image-to-image workflow using a FLUX dev model. ",
"version": "1.0.4",
"contact": "",
"tags": "image2image, flux, image-to-image",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
"exposedFields": [
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "clip_embed_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "vae_model"
},
{
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
"fieldName": "denoising_start"
},
{
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"fieldName": "prompt"
},
{
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
"fieldName": "num_steps"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"nodes": [
{
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
"type": "invocation",
"data": {
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
"type": "flux_vae_encode",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"image": {
"name": "image",
"label": "",
"value": {
"image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png"
}
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 732.7680166609682,
"y": -24.37398171806909
}
},
{
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
"type": "invocation",
"data": {
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
"type": "flux_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"denoise_mask": {
"name": "denoise_mask",
"label": ""
},
"denoising_start": {
"name": "denoising_start",
"label": "",
"value": 0.04
},
"denoising_end": {
"name": "denoising_end",
"label": "",
"value": 1
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1182.8836633018684,
"y": -251.38882958913183
}
},
{
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "invocation",
"data": {
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "flux_vae_decode",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1575.5797431839133,
"y": -209.00150975507415
}
},
{
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "invocation",
"data": {
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "flux_model_loader",
"version": "1.0.4",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"model": {
"name": "model",
"label": "Model (dev variant recommended for Image-to-Image)"
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": "",
"value": {
"key": "fa23a584-b623-415d-832a-21b5098ff1a1",
"hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70",
"name": "clip-vit-large-patch14",
"base": "any",
"type": "clip_embed"
}
},
"vae_model": {
"name": "vae_model",
"label": "",
"value": {
"key": "74fc82ba-c0a8-479d-a890-2126f82da758",
"hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068",
"name": "FLUX.1-schnell_ae",
"base": "flux",
"type": "vae"
}
}
}
},
"position": {
"x": 328.1809894659957,
"y": -90.2241133566946
}
},
{
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "invocation",
"data": {
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "flux_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"clip": {
"name": "clip",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"t5_max_seq_len": {
"name": "t5_max_seq_len",
"label": "T5 Max Seq Len",
"value": 256
},
"prompt": {
"name": "prompt",
"label": "",
"value": "a cat wearing a birthday hat"
}
}
},
"position": {
"x": 745.8823365057267,
"y": -299.60249175851914
}
},
{
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "invocation",
"data": {
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 725.834098928012,
"y": 496.2710031089931
}
}
],
"edges": [
{
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height",
"type": "default",
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "height",
"targetHandle": "height"
},
{
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width",
"type": "default",
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "width",
"targetHandle": "width"
},
{
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents",
"type": "default",
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "2981a67c-480f-4237-9384-26b68dbf912b",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
"type": "default",
"source": "ace0258f-67d7-4eee-a218-6fff27065214",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
}
]
}

View File

@@ -1,7 +1,7 @@
{
"name": "FLUX Text to Image",
"author": "InvokeAI",
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"version": "1.0.4",
"contact": "",
"tags": "text2image, flux",
@@ -11,25 +11,17 @@
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "clip_embed_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "vae_model"
},
{
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"fieldName": "prompt"
},
{
"nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"fieldName": "num_steps"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
}
],
"meta": {
@@ -37,121 +29,6 @@
"category": "default"
},
"nodes": [
{
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"type": "invocation",
"data": {
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"type": "flux_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"denoise_mask": {
"name": "denoise_mask",
"label": ""
},
"denoising_start": {
"name": "denoising_start",
"label": "",
"value": 0
},
"denoising_end": {
"name": "denoising_end",
"label": "",
"value": 1
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1186.1868226120378,
"y": -214.9459927686657
}
},
{
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "invocation",
"data": {
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "flux_vae_decode",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1575.5797431839133,
"y": -209.00150975507415
}
},
{
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "invocation",
@@ -222,8 +99,8 @@
}
},
"position": {
"x": 778.4899149328337,
"y": -100.36469216659502
"x": 824.1970602278849,
"y": 146.98251001061735
}
},
{
@@ -252,52 +129,77 @@
}
},
"position": {
"x": 800.9667463219505,
"y": 285.8297267547506
"x": 822.9899179655476,
"y": 360.9657214885052
}
},
{
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "invocation",
"data": {
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "flux_text_to_image",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1216.3900791301849,
"y": 5.500841807102248
}
}
],
"edges": [
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
"type": "default",
"source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
@@ -306,6 +208,14 @@
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
@@ -321,6 +231,30 @@
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "value",
"targetHandle": "seed"
}
]
}

View File

@@ -1,45 +0,0 @@
from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
def denoise(
model: Flux,
# model input
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
):
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
if inpaint_extension is not None:
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
step_callback()
return img

View File

@@ -1,35 +0,0 @@
import torch
class InpaintExtension:
"""A class for managing inpainting with FLUX."""
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
"""Initialize InpaintExtension.
Args:
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
inpainted region with the background. In 'packed' format.
noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
"""
assert init_latents.shape == inpaint_mask.shape == noise.shape
self._init_latents = init_latents
self._inpaint_mask = inpaint_mask
self._noise = noise
def merge_intermediate_latents_with_init_latents(
self, intermediate_latents: torch.Tensor, timestep: float
) -> torch.Tensor:
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
trajectory.
This function should be called after each denoising step.
"""
# Noise the init latents for the current timestep.
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)

View File

@@ -258,17 +258,16 @@ class Decoder(nn.Module):
class DiagonalGaussian(nn.Module):
def __init__(self, chunk_dim: int = 1):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if sample:
if self.sample:
std = torch.exp(0.5 * logvar)
# Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
# have to use torch.randn(...) instead.
return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
return mean + std * torch.randn_like(mean)
else:
return mean
@@ -298,21 +297,8 @@ class AutoEncoder(nn.Module):
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
"""Run VAE encoding on input tensor x.
Args:
x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
Defaults to True.
generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
Defaults to None.
Returns:
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
"""
z = self.reg(self.encoder(x), sample=sample, generator=generator)
def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z

View File

@@ -0,0 +1,167 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.conditioner import HFEncoder
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[], None],
guidance: float = 4.0,
):
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
step_callback()
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert an input image in latent space to patches for diffusion.
This implementation was extracted from:
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
Returns:
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
"""
bs, c, h, w = latent_img.shape
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
# Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return img, img_ids

View File

@@ -1,135 +0,0 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from typing import Callable
import torch
from einops import rearrange, repeat
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
"""Find the last index in timesteps that is >= val.
We use epsilon-close equality to avoid potential floating point errors.
"""
idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
assert idx >= 0
return idx
def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
"""Clip the timestep schedule to the denoising range.
Args:
timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
would mean that the denoising process start at the last timestep in the schedule >= 0.8.
denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
mean that the denoising process end at the last timestep in the schedule >= 0.2.
Returns:
list[float]: The clipped timestep schedule.
"""
assert 0.0 <= denoising_start <= 1.0
assert 0.0 <= denoising_end <= 1.0
assert denoising_start <= denoising_end
t_start_val = 1.0 - denoising_start
t_end_val = 1.0 - denoising_end
t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
return clipped_timesteps
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""Unpack flat array of patch embeddings to latent image."""
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
def pack(x: torch.Tensor) -> torch.Tensor:
"""Pack latent image to flattented array of patch embeddings."""
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
"""Generate tensor of image position ids.
Args:
h (int): Height of image in latent space.
w (int): Width of image in latent space.
batch_size (int): Batch size.
device (torch.device): Device.
dtype (torch.dtype): dtype.
Returns:
torch.Tensor: Image position ids.
"""
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids

672
invokeai/backend/lora.py Normal file
View File

@@ -0,0 +1,672 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from safetensors.torch import load_file
from typing_extensions import Self
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
self.mid = values.get("lora_mid.weight", None)
self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)
self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
else:
self.w1_b = None
self.w1_a = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
else:
self.w2_a = None
self.w2_b = None
self.t2 = values.get("lokr_t2", None)
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["diff"]
self.bias = values.get("diff_b", None)
self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
assert orig_weight is not None
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
class NormLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["w_norm"]
self.bias = values.get("b_norm", None)
self.rank = None # unscaled
self.check_keys(values, {"w_norm", "b_norm"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
class LoRAModelRaw(RawModel): # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]
def __init__(
self,
name: str,
layers: Dict[str, AnyLoRALayer],
):
self._name = name
self.layers = layers
@property
def name(self) -> str:
return self._name
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}

View File

@@ -66,9 +66,8 @@ class ModelLoader(ModelLoaderBase):
return (model_base / config.path).resolve()
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
try:
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
return self._ram_cache.get(config.key, submodel_type)
except IndexError:
pass
@@ -85,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
return self._ram_cache.get(
key=config.key,
submodel_type=submodel_type,
stats_name=stats_name,
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
)
def get_size_fs(

View File

@@ -128,24 +128,7 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def max_cache_size(self) -> float:
"""Return the maximum size the RAM cache can grow to."""
pass
@max_cache_size.setter
@abstractmethod
def max_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
@property
@abstractmethod
def max_vram_cache_size(self) -> float:
"""Return the maximum size the VRAM cache can grow to."""
pass
@max_vram_cache_size.setter
@abstractmethod
def max_vram_cache_size(self, value: float) -> float:
"""Set the maximum size the VRAM cache can grow to."""
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@abstractmethod

View File

@@ -70,7 +70,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
lazy_offloading: bool = True,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
@@ -82,13 +81,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
:param logger: InvokeAILogger to use (otherwise creates one)
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
@@ -133,16 +130,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Set the cap on cache size."""
self._max_cache_size = value
@property
def max_vram_cache_size(self) -> float:
"""Return the cap on vram cache size."""
return self._max_vram_cache_size
@max_vram_cache_size.setter
def max_vram_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
self._max_vram_cache_size = value
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""

View File

@@ -5,10 +5,8 @@ from logging import Logger
from pathlib import Path
from typing import Optional
import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -20,11 +18,6 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import (
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.peft.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.peft.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@@ -52,28 +45,14 @@ class LoRALoader(ModelLoader):
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
# Load the state dict from the model file.
if model_path.suffix == ".safetensors":
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(model_path, map_location="cpu")
# Apply state_dict key conversions, if necessary.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
model = lora_model_from_sd_state_dict(state_dict=state_dict)
else:
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
model.to(dtype=self._torch_dtype)
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
return model
# override
def _get_model_path(self, config: AnyModelConfig) -> Path:
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
self._model_base = config.base

View File

@@ -15,9 +15,9 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw

View File

@@ -26,7 +26,6 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -529,11 +528,9 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return ModelFormat("lycoris")
def get_base_type(self) -> BaseModelType:
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint):
return BaseModelType.Flux
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
# If we've gotten here, we assume that the model is a Stable Diffusion model.
token_vector_length = lora_token_vector_length(self.checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:

View File

@@ -13,10 +13,10 @@ from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

View File

@@ -1,84 +0,0 @@
import re
from typing import Any, Dict, TypeVar
import torch
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.utils import peft_layer_from_state_dict
from invokeai.backend.peft.lora import LoRAModelRaw
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_double_blocks_0_img_attn_proj.alpha
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
for k in state_dict.keys():
if not re.match(FLUX_KOHYA_KEY_REGEX, k):
return False
return True
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
layer_name, param_name = key.split(".", 1)
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Convert the state dict to the InvokeAI format.
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
layer = peft_layer_from_state_dict(layer_key, layer_state_dict)
layers[layer_key] = layer
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
T = TypeVar("T")
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
Example key conversions:
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"""
def replace_func(match: re.Match[str]) -> str:
s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
if match.group(4):
s += f".{match.group(4)}"
return s
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict

View File

@@ -1,30 +0,0 @@
from typing import Dict
import torch
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.utils import peft_layer_from_state_dict
from invokeai.backend.peft.lora import LoRAModelRaw
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
layers: dict[str, AnyLoRALayer] = {}
for layer_key, values in grouped_state_dict.items():
layer = peft_layer_from_state_dict(layer_key, values)
layers[layer_key] = layer
return LoRAModelRaw(layers=layers)
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped

View File

@@ -1,154 +0,0 @@
import bisect
from typing import Dict, List, Tuple, TypeVar
T = TypeVar("T")
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict: dict[str, T] = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer: list[tuple[str, str]] = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map: list[tuple[str, str]] = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
}

View File

@@ -1,10 +0,0 @@
from typing import Union
from invokeai.backend.peft.layers.full_layer import FullLayer
from invokeai.backend.peft.layers.ia3_layer import IA3Layer
from invokeai.backend.peft.layers.loha_layer import LoHALayer
from invokeai.backend.peft.layers.lokr_layer import LoKRLayer
from invokeai.backend.peft.layers.lora_layer import LoRALayer
from invokeai.backend.peft.layers.norm_layer import NormLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]

View File

@@ -1,37 +0,0 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["diff"]
self.bias = values.get("diff_b", None)
self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)

View File

@@ -1,42 +0,0 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
assert orig_weight is not None
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)

View File

@@ -1,68 +0,0 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)
self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)

View File

@@ -1,114 +0,0 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
else:
self.w1_b = None
self.w1_a = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
else:
self.w2_a = None
self.w2_b = None
self.t2 = values.get("lokr_t2", None)
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)

View File

@@ -1,59 +0,0 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
self.mid = values.get("lora_mid.weight", None)
self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)

View File

@@ -1,74 +0,0 @@
from typing import Dict, Optional, Set
import torch
import invokeai.backend.util.logging as logger
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)

View File

@@ -1,37 +0,0 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class NormLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["w_norm"]
self.bias = values.get("b_norm", None)
self.rank = None # unscaled
self.check_keys(values, {"w_norm", "b_norm"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)

View File

@@ -1,33 +0,0 @@
from typing import Dict
import torch
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.full_layer import FullLayer
from invokeai.backend.peft.layers.ia3_layer import IA3Layer
from invokeai.backend.peft.layers.loha_layer import LoHALayer
from invokeai.backend.peft.layers.lokr_layer import LoKRLayer
from invokeai.backend.peft.layers.lora_layer import LoRALayer
from invokeai.backend.peft.layers.norm_layer import NormLayer
def peft_layer_from_state_dict(layer_key: str, state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
if "lora_up.weight" in state_dict:
# LoRA a.k.a LoCon
return LoRALayer(layer_key, state_dict)
elif "hada_w1_a" in state_dict:
return LoHALayer(layer_key, state_dict)
elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
return LoKRLayer(layer_key, state_dict)
elif "diff" in state_dict:
# Full a.k.a Diff
return FullLayer(layer_key, state_dict)
elif "on_input" in state_dict:
return IA3Layer(layer_key, state_dict)
elif "w_norm" in state_dict:
return NormLayer(layer_key, state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

View File

@@ -1,22 +0,0 @@
# Copyright (c) 2024 The InvokeAI Development team
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.raw_model import RawModel
class LoRAModelRaw(RawModel): # (torch.nn.Module):
def __init__(self, layers: Dict[str, AnyLoRALayer]):
self.layers = layers
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size

View File

@@ -1,102 +0,0 @@
from contextlib import contextmanager
from typing import Dict, Iterator, Optional, Tuple
import torch
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class PeftPatcher:
@classmethod
@torch.no_grad()
@contextmanager
def apply_peft_patches(
cls,
model: torch.nn.Module,
patches: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply one or more PEFT patches to a model.
:param model: The model to patch.
:param loras: An iterator that returns tuples of PEFT patches and associated weights. An iterator is used so
that the PEFT patches do not need to be loaded into memory all at once.
:param prefix: The keys in the patches will be filtered to only include weights with this prefix.
:cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
cls._apply_peft_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
)
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@classmethod
@torch.no_grad()
def _apply_peft_patch(
cls,
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply one a LoRA to a model.
:param model: The model to patch.
:param patch: LoRA model to patch in.
:param patch_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
"""
if patch_weight == 0:
return
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module = model.get_submodule(layer_key)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = layer_key + "." + param_name
module_param = module.get_parameter(param_name)
# Save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)

View File

@@ -12,7 +12,7 @@ from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

View File

@@ -12,6 +12,10 @@ module.exports = {
'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'error',
// https://eslint.org/docs/latest/rules/no-promise-executor-return
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
},
overrides: [
/**

View File

@@ -1,5 +1,5 @@
import { PropsWithChildren, memo, useEffect } from 'react';
import { modelChanged } from '../src/features/parameters/store/generationSlice';
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
import { useAppDispatch } from '../src/app/store/storeHooks';
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
/**
@@ -10,7 +10,9 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
useGlobalModifiersInit();
useEffect(() => {
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
dispatch(
modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
);
}, []);
return props.children;

View File

@@ -9,6 +9,8 @@ const config: KnipConfig = {
'src/services/api/schema.ts',
'src/features/nodes/types/v1/**',
'src/features/nodes/types/v2/**',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@@ -24,7 +24,7 @@
"build": "pnpm run lint && vite build",
"typegen": "node scripts/typegen.js",
"preview": "vite preview",
"lint:knip": "knip",
"lint:knip": "knip --tags=-knipignore",
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
"lint:prettier": "prettier --check .",
@@ -52,18 +52,19 @@
}
},
"dependencies": {
"@chakra-ui/react-use-size": "^2.1.0",
"@dagrejs/dagre": "^1.1.3",
"@dagrejs/graphlib": "^2.2.3",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.0.20",
"@invoke-ai/ui-library": "^0.0.29",
"@invoke-ai/ui-library": "^0.0.32",
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.2.3",
"@roarr/browser-log-writer": "^1.3.0",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.1",
"cmdk": "^1.0.0",
"compare-versions": "^6.1.1",
"dateformat": "^5.0.3",
"fracturedjsonjs": "^4.0.2",
@@ -74,6 +75,8 @@
"jsondiffpatch": "^0.6.0",
"konva": "^9.3.14",
"lodash-es": "^4.17.21",
"lru-cache": "^11.0.0",
"nanoid": "^5.0.7",
"nanostores": "^0.11.2",
"new-github-issue-url": "^1.0.0",
"overlayscrollbars": "^2.10.0",
@@ -88,10 +91,8 @@
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^14.1.3",
"react-icons": "^5.2.1",
"react-konva": "^18.2.10",
"react-redux": "9.1.2",
"react-resizable-panels": "^2.0.23",
"react-select": "5.8.0",
"react-use": "^17.5.1",
"react-virtuoso": "^4.9.0",
"reactflow": "^11.11.4",
@@ -102,9 +103,9 @@
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.5",
"stable-hash": "^0.0.4",
"use-debounce": "^10.0.2",
"use-device-pixel-ratio": "^1.1.2",
"use-image": "^1.1.1",
"uuid": "^10.0.0",
"zod": "^3.23.8",
"zod-validation-error": "^3.3.1"

File diff suppressed because it is too large Load Diff

View File

@@ -80,6 +80,7 @@
"aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power",
"accept": "Accept",
"apply": "Apply",
"add": "Add",
"advanced": "Advanced",
"ai": "ai",
@@ -115,6 +116,7 @@
"githubLabel": "Github",
"goTo": "Go to",
"hotkeysLabel": "Hotkeys",
"loadingImage": "Loading Image",
"imageFailedToLoad": "Unable to Load Image",
"img2img": "Image To Image",
"inpaint": "inpaint",
@@ -162,10 +164,10 @@
"alpha": "Alpha",
"selected": "Selected",
"tab": "Tab",
"viewing": "Viewing",
"viewingDesc": "Review images in a large gallery view",
"editing": "Editing",
"editingDesc": "Edit on the Control Layers canvas",
"view": "View",
"viewDesc": "Review images in a large gallery view",
"edit": "Edit",
"editDesc": "Edit on the Canvas",
"comparing": "Comparing",
"comparingDesc": "Comparing two images",
"enabled": "Enabled",
@@ -325,6 +327,14 @@
"canceled": "Canceled",
"completedIn": "Completed in",
"batch": "Batch",
"origin": "Origin",
"destination": "Destination",
"upscaling": "Upscaling",
"canvas": "Canvas",
"generation": "Generation",
"workflows": "Workflows",
"other": "Other",
"gallery": "Gallery",
"batchFieldValues": "Batch Field Values",
"item": "Item",
"session": "Session",
@@ -1100,7 +1110,6 @@
"confirmOnDelete": "Confirm On Delete",
"developer": "Developer",
"displayInProgress": "Display Progress Images",
"enableImageDebugging": "Enable Image Debugging",
"enableInformationalPopovers": "Enable Informational Popovers",
"informationalPopoversDisabled": "Informational Popovers Disabled",
"informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
@@ -1567,7 +1576,7 @@
"copyToClipboard": "Copy to Clipboard",
"cursorPosition": "Cursor Position",
"darkenOutsideSelection": "Darken Outside Selection",
"discardAll": "Discard All",
"discardAll": "Discard All & Cancel Pending Generations",
"discardCurrent": "Discard Current",
"downloadAsImage": "Download As Image",
"enableMask": "Enable Mask",
@@ -1645,39 +1654,143 @@
"storeNotInitialized": "Store is not initialized"
},
"controlLayers": {
"deleteAll": "Delete All",
"saveCanvasToGallery": "Save Canvas To Gallery",
"saveBboxToGallery": "Save Bbox To Gallery",
"savedToGalleryOk": "Saved to Gallery",
"savedToGalleryError": "Error saving to gallery",
"mergeVisible": "Merge Visible",
"mergeVisibleOk": "Merged visible layers",
"mergeVisibleError": "Error merging visible layers",
"clearHistory": "Clear History",
"generateMode": "Generate",
"generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
"composeMode": "Compose",
"composeModeDesc": "Compose your work iterative. Generated images are added back to the canvas.",
"autoSave": "Auto-save to Gallery",
"resetCanvas": "Reset Canvas",
"resetAll": "Reset All",
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
"moveToBack": "Move to Back",
"moveForward": "Move Forward",
"moveBackward": "Move Backward",
"brushSize": "Brush Size",
"width": "Width",
"zoom": "Zoom",
"resetView": "Reset View",
"controlLayers": "Control Layers",
"globalMaskOpacity": "Global Mask Opacity",
"autoNegative": "Auto Negative",
"enableAutoNegative": "Enable Auto Negative",
"disableAutoNegative": "Disable Auto Negative",
"deletePrompt": "Delete Prompt",
"resetRegion": "Reset Region",
"debugLayers": "Debug Layers",
"showHUD": "Show HUD",
"rectangle": "Rectangle",
"maskPreviewColor": "Mask Preview Color",
"maskFill": "Mask Fill",
"addPositivePrompt": "Add $t(common.positivePrompt)",
"addNegativePrompt": "Add $t(common.negativePrompt)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"regionalGuidance": "Regional Guidance",
"addRasterLayer": "Add $t(controlLayers.rasterLayer)",
"addControlLayer": "Add $t(controlLayers.controlLayer)",
"addInpaintMask": "Add $t(controlLayers.inpaintMask)",
"addRegionalGuidance": "Add $t(controlLayers.regionalGuidance)",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
"raster": "Raster",
"rasterLayer": "Raster Layer",
"controlLayer": "Control Layer",
"inpaintMask": "Inpaint Mask",
"regionalGuidance": "Regional Guidance",
"ipAdapter": "IP Adapter",
"sendToGallery": "Send To Gallery",
"sendToGalleryDesc": "Generations will be sent to the gallery.",
"sendToCanvas": "Send To Canvas",
"sendToCanvasDesc": "Generations will be staged onto the canvas.",
"rasterLayer_withCount_one": "$t(controlLayers.rasterLayer)",
"controlLayer_withCount_one": "$t(controlLayers.controlLayer)",
"inpaintMask_withCount_one": "$t(controlLayers.inpaintMask)",
"regionalGuidance_withCount_one": "$t(controlLayers.regionalGuidance)",
"ipAdapter_withCount_one": "$t(controlLayers.ipAdapter)",
"rasterLayer_withCount_other": "Raster Layers",
"controlLayer_withCount_other": "Control Layers",
"inpaintMask_withCount_other": "Inpaint Masks",
"regionalGuidance_withCount_other": "Regional Guidance",
"ipAdapter_withCount_other": "IP Adapters",
"opacity": "Opacity",
"regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
"controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
"rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
"ipAdapters_withCount_hidden": "IP Adapters ({{count}} hidden)",
"inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
"regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
"controlLayers_withCount_visible": "Control Layers ({{count}})",
"rasterLayers_withCount_visible": "Raster Layers ({{count}})",
"ipAdapters_withCount_visible": "IP Adapters ({{count}})",
"inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
"globalIPAdapter": "Global $t(common.ipAdapter)",
"globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
"globalInitialImage": "Global Initial Image",
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
"layer": "Layer",
"opacityFilter": "Opacity Filter",
"clearProcessor": "Clear Processor",
"resetProcessor": "Reset Processor to Defaults",
"noLayersAdded": "No Layers Added",
"layers_one": "Layer",
"layers_other": "Layers"
"layers_other": "Layers",
"objects_zero": "empty",
"objects_one": "{{count}} object",
"objects_other": "{{count}} objects",
"convertToControlLayer": "Convert to Control Layer",
"convertToRasterLayer": "Convert to Raster Layer",
"transparency": "Transparency",
"enableTransparencyEffect": "Enable Transparency Effect",
"disableTransparencyEffect": "Disable Transparency Effect",
"hidingType": "Hiding {{type}}",
"showingType": "Showing {{type}}",
"dynamicGrid": "Dynamic Grid",
"logDebugInfo": "Log Debug Info",
"locked": "Locked",
"unlocked": "Unlocked",
"deleteSelected": "Delete Selected",
"deleteAll": "Delete All",
"flipHorizontal": "Flip Horizontal",
"flipVertical": "Flip Vertical",
"fill": {
"fillColor": "Fill Color",
"fillStyle": "Fill Style",
"solid": "Solid",
"grid": "Grid",
"crosshatch": "Crosshatch",
"vertical": "Vertical",
"horizontal": "Horizontal",
"diagonal": "Diagonal"
},
"tool": {
"brush": "Brush",
"eraser": "Eraser",
"rectangle": "Rectangle",
"bbox": "Bbox",
"move": "Move",
"view": "View",
"transform": "Transform",
"colorPicker": "Color Picker"
},
"filter": {
"filter": "Filter",
"filters": "Filters",
"filterType": "Filter Type",
"preview": "Preview",
"apply": "Apply",
"cancel": "Cancel"
}
},
"upscaling": {
"upscale": "Upscale",
@@ -1765,5 +1878,30 @@
"upscaling": "Upscaling",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
}
},
"system": {
"enableLogging": "Enable Logging",
"logLevel": {
"logLevel": "Log Level",
"trace": "Trace",
"debug": "Debug",
"info": "Info",
"warn": "Warn",
"error": "Error",
"fatal": "Fatal"
},
"logNamespaces": {
"logNamespaces": "Log Namespaces",
"gallery": "Gallery",
"models": "Models",
"config": "Config",
"canvas": "Canvas",
"generation": "Generation",
"workflows": "Workflows",
"system": "System",
"events": "Events",
"queue": "Queue",
"metadata": "Metadata"
}
}
}

View File

@@ -38,7 +38,7 @@ async function generateTypes(schema) {
process.stdout.write(`\nOK!\r\n`);
}
async function main() {
function main() {
const encoding = 'utf-8';
if (process.stdin.isTTY) {

View File

@@ -6,6 +6,7 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
import { useScopeFocusWatcher } from 'common/hooks/interactionScopes';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
@@ -13,13 +14,16 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { selectLanguage } from 'features/system/store/systemSelectors';
import { AppContent } from 'features/ui/components/AppContent';
import { setActiveTab } from 'features/ui/store/uiSlice';
import type { TabName } from 'features/ui/store/uiTypes';
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
import { AnimatePresence } from 'framer-motion';
import i18n from 'i18n';
@@ -41,7 +45,7 @@ interface Props {
};
selectedWorkflowId?: string;
selectedStylePresetId?: string;
destination?: InvokeTabName | undefined;
destination?: TabName;
}
const App = ({
@@ -51,7 +55,7 @@ const App = ({
selectedStylePresetId,
destination,
}: Props) => {
const language = useAppSelector(languageSelector);
const language = useAppSelector(selectLanguage);
const logger = useLogger('system');
const dispatch = useAppDispatch();
const clearStorage = useClearStorage();
@@ -107,6 +111,7 @@ const App = ({
useStarterModelsToast();
useSyncQueueStatus();
useScopeFocusWatcher();
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
@@ -119,7 +124,7 @@ const App = ({
{...dropzone.getRootProps()}
>
<input {...dropzone.getInputProps()} />
<InvokeTabs />
<AppContent />
<AnimatePresence>
{dropzone.isDragActive && isHandlingUpload && (
<ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
@@ -130,7 +135,10 @@ const App = ({
<ChangeBoardModal />
<DynamicPromptsModal />
<StylePresetModal />
<ClearQueueConfirmationsAlertDialog />
<PreselectedImage selectedImage={selectedImage} />
<SettingsModal />
<RefreshAfterResetModal />
</ErrorBoundary>
);
};

View File

@@ -1,5 +1,7 @@
import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import newGithubIssueUrl from 'new-github-issue-url';
import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
@@ -13,9 +15,11 @@ type Props = {
resetErrorBoundary: () => void;
};
const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
const { t } = useTranslation();
const isLocal = useAppSelector((s) => s.config.isLocal);
const isLocal = useAppSelector(selectIsLocal);
const handleCopy = useCallback(() => {
const text = JSON.stringify(serializeError(error), null, 2);

View File

@@ -19,7 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { TabName } from 'features/ui/store/uiTypes';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
@@ -46,7 +46,7 @@ interface Props extends PropsWithChildren {
};
selectedWorkflowId?: string;
selectedStylePresetId?: string;
destination?: InvokeTabName;
destination?: TabName;
customStarUi?: CustomStarUi;
socketOptions?: Partial<ManagerOptions & SocketOptions>;
isDebugging?: boolean;

View File

@@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppStore } from 'app/store/nanostores/store';
import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
@@ -18,14 +18,19 @@ declare global {
}
}
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;
export const $socket = atom<AppSocket | null>(null);
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
const $isSocketInitialized = atom<boolean>(false);
export const $isConnected = atom<boolean>(false);
/**
* Initializes the socket.io connection and sets up event listeners.
*/
export const useSocketIO = () => {
const dispatch = useAppDispatch();
const { dispatch, getState } = useAppStore();
const baseUrl = useStore($baseUrl);
const authToken = useStore($authToken);
const addlSocketOptions = useStore($socketOptions);
@@ -61,8 +66,9 @@ export const useSocketIO = () => {
return;
}
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(socketUrl, socketOptions);
setEventListeners({ dispatch, socket });
const socket: AppSocket = io(socketUrl, socketOptions);
$socket.set(socket);
setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
socket.connect();
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@@ -84,5 +90,5 @@ export const useSocketIO = () => {
socket.disconnect();
$isSocketInitialized.set(false);
};
}, [dispatch, socketOptions, socketUrl]);
}, [dispatch, getState, socketOptions, socketUrl]);
};

View File

@@ -15,21 +15,21 @@ export const BASE_CONTEXT = {};
export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
export type LoggerNamespace =
| 'images'
| 'models'
| 'config'
| 'canvas'
| 'generation'
| 'nodes'
| 'system'
| 'socketio'
| 'session'
| 'queue'
| 'dnd'
| 'controlLayers';
export const zLogNamespace = z.enum([
'canvas',
'config',
'events',
'gallery',
'generation',
'metadata',
'models',
'system',
'queue',
'workflows',
]);
export type LogNamespace = z.infer<typeof zLogNamespace>;
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });
export const logger = (namespace: LogNamespace) => $logger.get().child({ namespace });
export const zLogLevel = z.enum(['trace', 'debug', 'info', 'warn', 'error', 'fatal']);
export type LogLevel = z.infer<typeof zLogLevel>;

View File

@@ -1,29 +1,41 @@
import { createLogWriter } from '@roarr/browser-log-writer';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectSystemLogIsEnabled,
selectSystemLogLevel,
selectSystemLogNamespaces,
} from 'features/system/store/systemSlice';
import { useEffect, useMemo } from 'react';
import { ROARR, Roarr } from 'roarr';
import type { LoggerNamespace } from './logger';
import type { LogNamespace } from './logger';
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';
export const useLogger = (namespace: LoggerNamespace) => {
const consoleLogLevel = useAppSelector((s) => s.system.consoleLogLevel);
const shouldLogToConsole = useAppSelector((s) => s.system.shouldLogToConsole);
export const useLogger = (namespace: LogNamespace) => {
const logLevel = useAppSelector(selectSystemLogLevel);
const logNamespaces = useAppSelector(selectSystemLogNamespaces);
const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);
// The provided Roarr browser log writer uses localStorage to config logging to console
useEffect(() => {
if (shouldLogToConsole) {
if (logIsEnabled) {
// Enable console log output
localStorage.setItem('ROARR_LOG', 'true');
// Use a filter to show only logs of the given level
localStorage.setItem('ROARR_FILTER', `context.logLevel:>=${LOG_LEVEL_MAP[consoleLogLevel]}`);
let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
if (logNamespaces.length > 0) {
filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
} else {
filter += ' AND context.namespace:undefined';
}
localStorage.setItem('ROARR_FILTER', filter);
} else {
// Disable console log output
localStorage.setItem('ROARR_LOG', 'false');
}
ROARR.write = createLogWriter();
}, [consoleLogLevel, shouldLogToConsole]);
}, [logLevel, logIsEnabled, logNamespaces]);
// Update the module-scoped logger context as needed
useEffect(() => {

View File

@@ -1,7 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { TabName } from 'features/ui/store/uiTypes';
export const enqueueRequested = createAction<{
tabName: InvokeTabName;
tabName: TabName;
prepend: boolean;
}>('app/enqueueRequested');

View File

@@ -1,2 +1,3 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
export const EMPTY_OBJECT = {};

View File

@@ -1,5 +1,6 @@
import { createDraftSafeSelectorCreator, createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
import type { GetSelectorsOptions } from '@reduxjs/toolkit/dist/entities/state_selectors';
import type { RootState } from 'app/store/store';
import { isEqual } from 'lodash-es';
/**
@@ -19,3 +20,5 @@ export const getSelectorsOptions: GetSelectorsOptions = {
argsMemoize: lruMemoize,
}),
};
export const createMemoizedAppSelector = createMemoizedSelector.withTypes<RootState>();

View File

@@ -1,5 +1,4 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { PersistError, RehydrateError } from 'redux-remember';
import { serializeError } from 'serialize-error';
@@ -41,6 +40,6 @@ export const errorHandler = (err: PersistError | RehydrateError) => {
} else if (err instanceof RehydrateError) {
log.error({ error: serializeError(err) }, 'Problem rehydrating state');
} else {
log.error({ error: parseify(err) }, 'Problem in persistence layer');
log.error({ error: serializeError(err) }, 'Problem in persistence layer');
}
};

View File

@@ -1,9 +1,7 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
@@ -24,13 +22,5 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
};
}
if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
}
return sanitized;
}
return action;
};

View File

@@ -1,7 +1,7 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addCommitStagingAreaImageListener } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@@ -9,17 +9,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addCanvasCopiedToClipboardListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage';
import { addCanvasImageToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet';
import { addCanvasMaskSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery';
import { addCanvasMaskToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet';
import { addCanvasMergedListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery';
import { addControlAdapterPreprocessor } from 'app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor';
import { addControlNetAutoProcessListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess';
import { addControlNetImageProcessedListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed';
import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
@@ -37,16 +26,7 @@ import { addModelSelectedListener } from 'app/store/middleware/listenerMiddlewar
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
@@ -83,7 +63,6 @@ addGalleryImageClickedListener(startAppListening);
addGalleryOffsetChangedListener(startAppListening);
// User Invoked
addEnqueueRequestedCanvasListener(startAppListening);
addEnqueueRequestedNodes(startAppListening);
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
@@ -91,31 +70,22 @@ addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Canvas actions
addCanvasSavedToGalleryListener(startAppListening);
addCanvasMaskSavedToGalleryListener(startAppListening);
addCanvasImageToControlNetListener(startAppListening);
addCanvasMaskToControlNetListener(startAppListening);
addCanvasDownloadedAsImageListener(startAppListening);
addCanvasCopiedToClipboardListener(startAppListening);
addCanvasMergedListener(startAppListening);
addStagingAreaImageSavedListener(startAppListening);
addCommitStagingAreaImageListener(startAppListening);
// addCanvasSavedToGalleryListener(startAppListening);
// addCanvasMaskSavedToGalleryListener(startAppListening);
// addCanvasImageToControlNetListener(startAppListening);
// addCanvasMaskToControlNetListener(startAppListening);
// addCanvasDownloadedAsImageListener(startAppListening);
// addCanvasCopiedToClipboardListener(startAppListening);
// addCanvasMergedListener(startAppListening);
// addStagingAreaImageSavedListener(startAppListening);
// addCommitStagingAreaImageListener(startAppListening);
addStagingListeners(startAppListening);
// Socket.IO
addGeneratorProgressEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening);
addSocketDisconnectedEventListener(startAppListening);
addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening);
addBulkDownloadListeners(startAppListening);
// ControlNet
addControlNetImageProcessedListener(startAppListening);
addControlNetAutoProcessListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
@@ -148,4 +118,4 @@ addAdHocPostProcessingRequestedListener(startAppListening);
addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);
addControlAdapterPreprocessor(startAppListening);
// addControlAdapterPreprocessor(startAppListening);

View File

@@ -1,21 +1,21 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import type { SerializableObject } from 'common/types';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
const log = logger('queue');
export const adHocPostProcessingRequested = createAction<{ imageDTO: ImageDTO }>(`upscaling/postProcessingRequested`);
export const addAdHocPostProcessingRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: adHocPostProcessingRequested,
effect: async (action, { dispatch, getState }) => {
const log = logger('session');
const { imageDTO } = action.payload;
const state = getState();
@@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
} catch (error) {
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
if (error instanceof Object && 'status' in error && error.status === 403) {
return;

View File

@@ -23,7 +23,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
*/
startAppListening({
matcher: matchAnyBoardDeleted,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const state = getState();
const deletedBoardId = action.meta.arg.originalArgs;
const { autoAddBoardId, selectedBoardId } = state.gallery;
@@ -44,7 +44,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
// If we archived a board, it may end up hidden. If it's selected or the auto-add board, we should reset those.
startAppListening({
matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const state = getState();
const { shouldShowArchivedBoards } = state.gallery;
@@ -61,7 +61,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
// When we hide archived boards, if the selected or the auto-add board is archived, we should reset those.
startAppListening({
actionCreator: shouldShowArchivedBoardsChanged,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const shouldShowArchivedBoards = action.payload;
// We only need to take action if we have just hidden archived boards.
@@ -100,7 +100,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
*/
startAppListening({
matcher: boardsApi.endpoints.listAllBoards.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const boards = action.payload;
const state = getState();
const { selectedBoardId, autoAddBoardId } = state.gallery;

View File

@@ -1,33 +1,37 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import {
canvasBatchIdsReset,
commitStagingAreaImage,
discardStagedImages,
resetCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
sessionStagingAreaImageAccepted,
sessionStagingAreaReset,
} from 'features/controlLayers/store/canvasSessionSlice';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
import { assert } from 'tsafe';
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage);
const log = logger('canvas');
export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
export const addStagingListeners = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (_, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { batchIds } = state.canvas;
actionCreator: sessionStagingAreaReset,
effect: async (_, { dispatch }) => {
try {
const req = dispatch(
queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' })
queueApi.endpoints.cancelByBatchOrigin.initiate(
{ origin: 'canvas' },
{ fixedCacheKey: 'cancelByBatchOrigin' }
)
);
const { canceled } = await req.unwrap();
req.reset();
$lastCanvasProgressEvent.set(null);
if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`);
toast({
@@ -36,7 +40,6 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
status: 'success',
});
}
dispatch(canvasBatchIdsReset());
} catch {
log.error('Failed to cancel canvas batches');
toast({
@@ -47,4 +50,26 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
}
},
});
startAppListening({
actionCreator: sessionStagingAreaImageAccepted,
effect: (action, api) => {
const { index } = action.payload;
const state = api.getState();
const stagingAreaImage = state.canvasSession.stagedImages[index];
assert(stagingAreaImage, 'No staged image found to accept');
const { x, y } = selectCanvasSlice(state).bbox.rect;
const { imageDTO, offsetX, offsetY } = stagingAreaImage;
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: x + offsetX, y: y + offsetY },
objects: [imageObject],
};
api.dispatch(rasterLayerAdded({ overrides, isSelected: false }));
api.dispatch(sessionStagingAreaReset());
},
});
};

View File

@@ -4,7 +4,7 @@ import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: async (_, { dispatch, getState }) => {
effect: (_, { dispatch, getState }) => {
const { data } = selectQueueStatus(getState());
if (!data || data.processor.is_started) {

View File

@@ -1,14 +1,14 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
export const addAppConfigReceivedListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
effect: (action, { getState, dispatch }) => {
const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
const infillMethod = getState().generation.infillMethod;
const infillMethod = getState().params.infillMethod;
if (!infill_methods.includes(infillMethod)) {
// if there is no infill method, set it to the first one

View File

@@ -6,7 +6,7 @@ export const appStarted = createAction('app/appStarted');
export const addAppStartedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: appStarted,
effect: async (action, { unsubscribe, cancelActiveListeners }) => {
effect: (action, { unsubscribe, cancelActiveListeners }) => {
// this should only run once
cancelActiveListeners();
unsubscribe();

View File

@@ -1,27 +1,30 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import type { SerializableObject } from 'common/types';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
const log = logger('queue');
export const addBatchEnqueuedListener = (startAppListening: AppStartListening) => {
// success
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: async (action) => {
const response = action.payload;
effect: (action) => {
const enqueueResult = action.payload;
const arg = action.meta.arg.originalArgs;
logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued');
log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');
toast({
id: 'QUEUE_BATCH_SUCCEEDED',
title: t('queue.batchQueued'),
status: 'success',
description: t('queue.batchQueuedDesc', {
count: response.enqueued,
count: enqueueResult.enqueued,
direction: arg.prepend ? t('queue.front') : t('queue.back'),
}),
});
@@ -31,9 +34,9 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
// error
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchRejected,
effect: async (action) => {
effect: (action) => {
const response = action.payload;
const arg = action.meta.arg.originalArgs;
const batchConfig = action.meta.arg.originalArgs;
if (!response) {
toast({
@@ -42,7 +45,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
status: 'error',
description: t('common.unknownError'),
});
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
return;
}
@@ -68,7 +71,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
description: t('common.unknownError'),
});
}
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
},
});
};

View File

@@ -1,47 +1,31 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
import { allLayersDeleted } from 'features/controlLayers/store/controlLayersSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { imagesApi } from 'services/api/endpoints/images';
export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const { deleted_images } = action.payload;
// Remove all deleted images from the UI
let wasCanvasReset = false;
let wasNodeEditorReset = false;
let wereControlAdaptersReset = false;
let wereControlLayersReset = false;
const { canvas, nodes, controlAdapters, controlLayers } = getState();
const state = getState();
const nodes = selectNodesSlice(state);
const canvas = selectCanvasSlice(state);
deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
if (imageUsage.isCanvasImage && !wasCanvasReset) {
dispatch(resetCanvas());
wasCanvasReset = true;
}
const imageUsage = getImageUsage(nodes, canvas, image_name);
if (imageUsage.isNodesImage && !wasNodeEditorReset) {
dispatch(nodeEditorReset());
wasNodeEditorReset = true;
}
if (imageUsage.isControlImage && !wereControlAdaptersReset) {
dispatch(controlAdaptersReset());
wereControlAdaptersReset = true;
}
if (imageUsage.isControlLayerImage && !wereControlLayersReset) {
dispatch(allLayersDeleted());
wereControlLayersReset = true;
}
});
},
});

View File

@@ -1,21 +1,15 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadStarted,
} from 'services/events/actions';
const log = logger('images');
const log = logger('gallery');
export const addBulkDownloadListeners = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,
effect: async (action) => {
effect: (action) => {
log.debug(action.payload, 'Bulk download requested');
// If we have an item name, we are processing the bulk download locally and should use it as the toast id to
@@ -33,7 +27,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchRejected,
effect: async () => {
effect: () => {
log.debug('Bulk download request failed');
// There isn't any toast to update if we get this event.
@@ -44,55 +38,4 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
},
});
startAppListening({
actionCreator: socketBulkDownloadStarted,
effect: async (action) => {
// This should always happen immediately after the bulk download request, so we don't need to show a toast here.
log.debug(action.payload.data, 'Bulk download preparation started');
},
});
startAppListening({
actionCreator: socketBulkDownloadComplete,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed');
const { bulk_download_item_name } = action.payload.data;
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
const url = `/api/v1/images/download/${bulk_download_item_name}`;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady', 'Download ready'),
status: 'success',
description: (
<ExternalLink
label={t('gallery.clickToDownload', 'Click here to download')}
href={url}
download={bulk_download_item_name}
/>
),
duration: null,
});
},
});
startAppListening({
actionCreator: socketBulkDownloadError,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed');
const { bulk_download_item_name } = action.payload.data;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'),
status: 'error',
description: action.payload.data.error,
duration: null,
});
},
});
};

View File

@@ -1,38 +0,0 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasCopiedToClipboard,
effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
const state = getState();
try {
const blob = getBaseLayerBlob(state);
copyBlobToClipboard(blob);
} catch (err) {
moduleLog.error(String(err));
toast({
id: 'CANVAS_COPY_FAILED',
title: t('toast.problemCopyingCanvas'),
description: t('toast.problemCopyingCanvasDesc'),
status: 'error',
});
return;
}
toast({
id: 'CANVAS_COPY_SUCCEEDED',
title: t('toast.canvasCopiedClipboard'),
status: 'success',
});
},
});
};

View File

@@ -1,34 +0,0 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
import { downloadBlob } from 'features/canvas/util/downloadBlob';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasDownloadedAsImage,
effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' });
const state = getState();
let blob;
try {
blob = await getBaseLayerBlob(state);
} catch (err) {
moduleLog.error(String(err));
toast({
id: 'CANVAS_DOWNLOAD_FAILED',
title: t('toast.problemDownloadingCanvas'),
description: t('toast.problemDownloadingCanvasDesc'),
status: 'error',
});
return;
}
downloadBlob(blob, 'canvas.png');
toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' });
},
});
};

View File

@@ -1,60 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasImageToControlNetListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasImageToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
let blob: Blob;
try {
blob = await getBaseLayerBlob(state, true);
} catch (err) {
log.error(String(err));
toast({
id: 'PROBLEM_SAVING_CANVAS',
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'control',
is_intermediate: true,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {
type: 'TOAST',
title: t('toast.canvasSentControlnetAssets'),
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);
},
});
};

View File

@@ -1,60 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasMaskSavedToGallery,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { maskBlob } = canvasBlobsAndImageData;
if (!maskBlob) {
log.error('Problem getting mask layer blob');
toast({
id: 'PROBLEM_SAVING_MASK',
title: t('toast.problemSavingMask'),
description: t('toast.problemSavingMaskDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
title: t('toast.maskSavedAssets'),
},
})
);
},
});
};

View File

@@ -1,70 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasMaskToControlNetListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasMaskToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { maskBlob } = canvasBlobsAndImageData;
if (!maskBlob) {
log.error('Problem getting mask layer blob');
toast({
id: 'PROBLEM_IMPORTING_MASK',
title: t('toast.problemImportingMask'),
description: t('toast.problemImportingMaskDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: true,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {
type: 'TOAST',
title: t('toast.maskSentControlnetAssets'),
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);
},
});
};

View File

@@ -1,73 +0,0 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMerged } from 'features/canvas/store/actions';
import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasMergedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasMerged,
effect: async (action, { dispatch }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
const blob = await getFullBaseLayerBlob();
if (!blob) {
moduleLog.error('Problem getting base layer blob');
toast({
id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
});
return;
}
const canvasBaseLayer = $canvasBaseLayer.get();
if (!canvasBaseLayer) {
moduleLog.error('Problem getting canvas base layer');
toast({
id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
});
return;
}
const baseLayerRect = canvasBaseLayer.getClientRect({
relativeTo: canvasBaseLayer.getParent() ?? undefined,
});
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'mergedCanvas.png', {
type: 'image/png',
}),
image_category: 'general',
is_intermediate: true,
postUploadAction: {
type: 'TOAST',
title: t('toast.canvasMerged'),
},
})
).unwrap();
// TODO: I can't figure out how to do the type narrowing in the `take()` so just brute forcing it here
const { image_name } = imageDTO;
dispatch(
setMergedCanvas({
kind: 'image',
layer: 'base',
imageName: image_name,
...baseLayerRect,
})
);
},
});
};

View File

@@ -1,53 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasSavedToGallery,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
let blob;
try {
blob = await getBaseLayerBlob(state);
} catch (err) {
log.error(String(err));
toast({
id: 'CANVAS_SAVE_FAILED',
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'general',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
title: t('toast.canvasSavedGallery'),
},
metadata: {
_canvas_objects: parseify(state.canvas.layerState.objects),
},
})
);
},
});
};

View File

@@ -1,194 +0,0 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch } from 'app/store/store';
import { parseify } from 'common/util/serialize';
import {
caLayerImageChanged,
caLayerModelChanged,
caLayerProcessedImageChanged,
caLayerProcessorConfigChanged,
caLayerProcessorPendingBatchIdChanged,
caLayerRecalled,
isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe';
const matcher = isAnyOf(
caLayerImageChanged,
caLayerProcessedImageChanged,
caLayerProcessorConfigChanged,
caLayerModelChanged,
caLayerRecalled
);
const DEBOUNCE_MS = 300;
const log = logger('session');
/**
* Simple helper to cancel a batch and reset the pending batch ID
*/
const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batchId: string) => {
const req = dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batchId] }));
log.trace({ batchId }, 'Cancelling existing preprocessor batch');
try {
await req.unwrap();
} catch {
// no-op
} finally {
req.reset();
// Always reset the pending batch ID - the cancel req could fail if the batch doesn't exist
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
}
};
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
const state = getState();
const originalState = getOriginalState();
// Cancel any in-progress instances of this listener
cancelActiveListeners();
log.trace('Control Layer CA auto-process triggered');
// Delay before starting actual work
await delay(DEBOUNCE_MS);
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
if (!layer) {
return;
}
// We should only process if the processor settings or image have changed
const originalLayer = originalState.controlLayers.present.layers
.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
const originalImage = originalLayer?.controlAdapter.image;
const originalConfig = originalLayer?.controlAdapter.processorConfig;
const image = layer.controlAdapter.image;
const processedImage = layer.controlAdapter.processedImage;
const config = layer.controlAdapter.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
// Neither config nor image have changed, we can bail
return;
}
if (!image || !config) {
// - If we have no image, we have nothing to process
// - If we have no processor config, we have nothing to process
// Clear the processed image and bail
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
return;
}
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
// If there is a pending processor batch, cancel it.
if (layer.controlAdapter.processorPendingBatchId) {
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
}
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[processorNode.id]: {
...processorNode,
// Control images are always intermediate - do not save to gallery
is_intermediate: true,
},
},
edges: [],
},
runs: 1,
},
};
// Kick off the processor batch
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
try {
const enqueueResult = await req.unwrap();
// TODO(psyche): Update the pydantic models, pretty sure we will _always_ have a batch_id here, but the model says it's optional
assert(enqueueResult.batch.batch_id, 'Batch ID not returned from queue');
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: enqueueResult.batch.batch_id }));
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
// Wait for the processor node to complete
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === processorNode.id
);
// We still have to check the output type
assert(
invocationCompleteAction.payload.data.result.type === 'image_output',
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
);
const { image_name } = invocationCompleteAction.payload.data.result.image;
const imageDTO = await getImageDTO(image_name);
assert(imageDTO, "Failed to fetch processor output's image DTO");
// Whew! We made it. Update the layer with the processed image
log.debug({ layerId, imageDTO }, 'ControlNet image processed');
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO }));
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
} catch (error) {
if (signal.aborted) {
// The listener was canceled - we need to cancel the pending processor batch, if there is one (could have changed by now).
const pendingBatchId = getState()
.controlLayers.present.layers.filter(isControlAdapterLayer)
.find((l) => l.id === layerId)?.controlAdapter.processorPendingBatchId;
if (pendingBatchId) {
cancelProcessorBatch(dispatch, layerId, pendingBatchId);
}
log.trace('Control Adapter preprocessor cancelled');
} else {
// Some other error condition...
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object) {
if ('data' in error && 'status' in error) {
if (error.status === 403) {
dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
return;
}
}
}
toast({
id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'),
status: 'error',
});
}
} finally {
req.reset();
}
},
});
};

View File

@@ -1,85 +0,0 @@
import type { AnyListenerPredicate } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import {
controlAdapterAutoConfigToggled,
controlAdapterImageChanged,
controlAdapterModelChanged,
controlAdapterProcessorParamsChanged,
controlAdapterProcessortTypeChanged,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
type AnyControlAdapterParamChangeAction =
| ReturnType<typeof controlAdapterProcessorParamsChanged>
| ReturnType<typeof controlAdapterModelChanged>
| ReturnType<typeof controlAdapterImageChanged>
| ReturnType<typeof controlAdapterProcessortTypeChanged>
| ReturnType<typeof controlAdapterAutoConfigToggled>;
const predicate: AnyListenerPredicate<RootState> = (action, state, prevState) => {
const isActionMatched =
controlAdapterProcessorParamsChanged.match(action) ||
controlAdapterModelChanged.match(action) ||
controlAdapterImageChanged.match(action) ||
controlAdapterProcessortTypeChanged.match(action) ||
controlAdapterAutoConfigToggled.match(action);
if (!isActionMatched) {
return false;
}
const { id } = action.payload;
const prevCA = selectControlAdapterById(prevState.controlAdapters, id);
const ca = selectControlAdapterById(state.controlAdapters, id);
if (!prevCA || !isControlNetOrT2IAdapter(prevCA) || !ca || !isControlNetOrT2IAdapter(ca)) {
return false;
}
if (controlAdapterAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (prevCA.shouldAutoConfig === true) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } = ca;
if (controlAdapterModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty
return false;
}
const isProcessorSelected = processorType !== 'none';
const hasControlImage = Boolean(controlImage);
return isProcessorSelected && hasControlImage;
};
const DEBOUNCE_MS = 300;
/**
* Listener that automatically processes a ControlNet image when its processor parameters are changed.
*
* The network request is debounced.
*/
export const addControlNetAutoProcessListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate,
effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
const log = logger('session');
const { id } = (action as AnyControlAdapterParamChangeAction).payload;
// Cancel any in-progress instances of this listener
cancelActiveListeners();
log.trace('ControlNet auto-process triggered');
// Delay before starting actual work
await delay(DEBOUNCE_MS);
dispatch(controlAdapterImageProcessed({ id }));
},
});
};

View File

@@ -1,118 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import {
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
pendingControlImagesCleared,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
export const addControlNetImageProcessedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: controlAdapterImageProcessed,
effect: async (action, { dispatch, getState, take }) => {
const log = logger('session');
const { id } = action.payload;
const ca = selectControlAdapterById(getState().controlAdapters, id);
if (!ca?.controlImage || !isControlNetOrT2IAdapter(ca)) {
log.error('Unable to process ControlNet image');
return;
}
if (ca.processorType === 'none' || ca.processorNode.type === 'none') {
return;
}
// ControlNet one-off procressing graph is just the processor node, no edges.
// Also we need to grab the image.
const nodeId = ca.processorNode.id;
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[ca.processorNode.id]: {
...ca.processorNode,
is_intermediate: true,
use_cache: false,
image: { image_name: ca.controlImage },
},
},
edges: [],
},
runs: 1,
},
};
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === nodeId
);
// We still have to check the output type
if (invocationCompleteAction.payload.data.result.type === 'image_output') {
const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received
const [{ payload }] = await take(
(action) =>
imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name
);
const processedControlImage = payload as ImageDTO;
log.debug({ controlNetId: action.payload, processedControlImage }, 'ControlNet image processed');
// Update the processed image in the store
dispatch(
controlAdapterProcessedImageChanged({
id,
processedControlImage: processedControlImage.image_name,
})
);
}
} catch (error) {
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object) {
if ('data' in error && 'status' in error) {
if (error.status === 403) {
dispatch(pendingControlImagesCleared());
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
return;
}
}
}
toast({
id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'),
status: 'error',
});
}
},
});
};

View File

@@ -1,144 +0,0 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { parseify } from 'common/util/serialize';
import { canvasBatchIdAdded, stagingAreaInitialized } from 'features/canvas/store/canvasSlice';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildCanvasGraph } from 'features/nodes/util/graph/canvas/buildCanvasGraph';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
/**
* This listener is responsible invoking the canvas. This involves a number of steps:
*
* 1. Generate image blobs from the canvas layers
* 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
* 3. Build the canvas graph
* 4. Create the session with the graph
* 5. Upload the init image if necessary
* 6. Upload the mask image if necessary
* 7. Update the init and mask images with the session ID
* 8. Initialize the staging area if not yet initialized
* 9. Dispatch the sessionReadyToInvoke action to invoke the session
*/
export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
effect: async (action, { getState, dispatch }) => {
const log = logger('queue');
const { prepend } = action.payload;
const state = getState();
const { layerState, boundingBoxCoordinates, boundingBoxDimensions, isMaskEnabled, shouldPreserveMaskedArea } =
state.canvas;
// Build canvas blobs
const canvasBlobsAndImageData = await getCanvasData(
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
log.error('Unable to create canvas data');
return;
}
const { baseBlob, baseImageData, maskBlob, maskImageData } = canvasBlobsAndImageData;
// Determine the generation mode
const generationMode = getCanvasGenerationMode(baseImageData, maskImageData);
if (state.system.enableImageDebugging) {
const baseDataURL = await blobToDataURL(baseBlob);
const maskDataURL = await blobToDataURL(maskBlob);
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask b64' },
{ base64: baseDataURL, caption: 'image b64' },
]);
}
log.debug(`Generation mode: ${generationMode}`);
// Temp placeholders for the init and mask images
let canvasInitImage: ImageDTO | undefined;
let canvasMaskImage: ImageDTO | undefined;
// For img2img and inpaint/outpaint, we need to upload the init images
if (['img2img', 'inpaint', 'outpaint'].includes(generationMode)) {
// upload the image, saving the request id
canvasInitImage = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([baseBlob], 'canvasInitImage.png', {
type: 'image/png',
}),
image_category: 'general',
is_intermediate: true,
})
).unwrap();
}
// For inpaint/outpaint, we also need to upload the mask layer
if (['inpaint', 'outpaint'].includes(generationMode)) {
// upload the image, saving the request id
canvasMaskImage = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: true,
})
).unwrap();
}
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
log.debug({ graph: parseify(graph) }, `Canvas graph built`);
// currently this action is just listened to for logging
dispatch(canvasGraphBuilt(graph));
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
// Prep the canvas staging area if it is not yet initialized
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
// Associate the session with the canvas session ID
dispatch(canvasBatchIdAdded(batchId));
} catch {
// no-op
}
},
});
};

View File

@@ -1,10 +1,21 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import type { SerializableObject } from 'common/types';
import type { Result } from 'common/util/result';
import { isErr, withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { sessionStagingAreaReset, sessionStartedStaging } from 'features/controlLayers/store/canvasSessionSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('generation');
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
@@ -12,33 +23,81 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
enqueueRequested.match(action) && action.payload.tabName === 'generation',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { shouldShowProgressInViewer } = state.ui;
const model = state.generation.model;
const model = state.params.model;
const { prepend } = action.payload;
let graph;
const manager = $canvasManager.get();
assert(manager, 'No model found in state');
if (model?.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph(state);
} else {
graph = await buildGenerationTabGraph(state);
let didStartStaging = false;
if (!state.canvasSession.isStaging && state.canvasSession.sendToCanvas) {
dispatch(sessionStartedStaging());
didStartStaging = true;
}
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const abortStaging = () => {
if (didStartStaging && getState().canvasSession.isStaging) {
dispatch(sessionStagingAreaReset());
}
};
let buildGraphResult: Result<
{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> },
Error
>;
assert(model, 'No model found in state');
const base = model.base;
switch (base) {
case 'sdxl':
buildGraphResult = await withResultAsync(() => buildSDXLGraph(state, manager));
break;
case 'sd-1':
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
default:
assert(false, `No graph builders for base ${base}`);
}
if (isErr(buildGraphResult)) {
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
abortStaging();
return;
}
const { g, noise, posCond } = buildGraphResult.value;
const destination = state.canvasSession.sendToCanvas ? 'canvas' : 'gallery';
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination)
);
if (isErr(prepareBatchResult)) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
abortStaging();
return;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, {
fixedCacheKey: 'enqueueBatch',
})
);
try {
await req.unwrap();
if (shouldShowProgressInViewer) {
dispatch(isImageViewerOpenChanged(true));
}
} finally {
req.reset();
req.reset();
const enqueueResult = await withResultAsync(() => req.unwrap());
if (isErr(enqueueResult)) {
log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
abortStaging();
return;
}
log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch');
},
});
};

View File

@@ -1,5 +1,6 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { queueApi } from 'services/api/endpoints/queue';
@@ -11,12 +12,12 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { nodes, edges } = state.nodes.present;
const nodes = selectNodesSlice(state);
const workflow = state.workflow;
const graph = buildNodesGraph(state.nodes.present);
const graph = buildNodesGraph(nodes);
const builtWorkflow = buildWorkflowWithValidation({
nodes,
edges,
nodes: nodes.nodes,
edges: nodes.edges,
workflow,
});
@@ -29,7 +30,9 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
batch: {
graph,
workflow: builtWorkflow,
runs: state.generation.iterations,
runs: state.params.iterations,
origin: 'workflows',
destination: 'gallery',
},
prepend: action.payload.prepend,
};

View File

@@ -14,9 +14,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
const { shouldShowProgressInViewer } = state.ui;
const { prepend } = action.payload;
const graph = await buildMultidiffusionUpscaleGraph(state);
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond, 'upscaling', 'gallery');
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {

View File

@@ -27,7 +27,7 @@ export const galleryImageClicked = createAction<{
export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: galleryImageClicked,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);

View File

@@ -1,24 +1,27 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { appInfoApi } from 'services/api/endpoints/appInfo';
const log = logger('system');
export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
effect: (action, { getState }) => {
const log = logger('system');
const schemaJSON = action.payload;
log.debug({ schemaJSON: parseify(schemaJSON) }, 'Received OpenAPI schema');
log.debug({ schemaJSON: parseify(schemaJSON) } as SerializableObject, 'Received OpenAPI schema');
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(schemaJSON, nodesAllowlist, nodesDenylist);
log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);
log.debug({ nodeTemplates } as SerializableObject, `Built ${size(nodeTemplates)} node templates`);
$templates.set(nodeTemplates);
},
@@ -30,8 +33,7 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
// If action.meta.condition === true, the request was canceled/skipped because another request was in flight or
// the value was already in the cache. We don't want to log these errors.
if (!action.meta.condition) {
const log = logger('system');
log.error({ error: parseify(action.error) }, 'Problem retrieving OpenAPI Schema');
log.error({ error: serializeError(action.error) }, 'Problem retrieving OpenAPI Schema');
}
},
});

Some files were not shown because too many files have changed in this diff Show More