Compare commits

...

1139 Commits

Author SHA1 Message Date
Kent Keirsey
3a14791da3 bria-ui-updates-wip 2025-07-25 12:58:27 -04:00
Ilan Tchenak
711a579945 fixed schema 2025-07-24 17:22:15 +00:00
Ilan Tchenak
1ac5a24a8a ruff fix 2025-07-24 19:10:29 +03:00
Ubuntu
282df322d5 fixed node issue 2025-07-24 10:56:37 -04:00
Ilan Tchenak
8523ea88f2 moved bria's nodes to invocations folder 2025-07-24 10:56:37 -04:00
Ubuntu
cad97d3da3 Small cosmetic fixes 2025-07-24 10:56:37 -04:00
Ubuntu
efc5a762fc removed unused file 2025-07-24 10:56:37 -04:00
Ubuntu
9131c45645 Added scikit-image required for Bria's OpenposeDetector model 2025-07-24 10:56:37 -04:00
Ilan Tchenak
75ca44d5f9 Add Bria text to image model and controlnet support 2025-07-24 10:56:37 -04:00
Ilan Tchenak
8b08af3949 Setup Probe and UI to accept bria controlnet models 2025-07-24 10:56:37 -04:00
Ubuntu
df9ea8dcc1 added bria nodes for bria3.1 and bria3.2 2025-07-24 10:56:37 -04:00
Ubuntu
25a57326b3 front end support for bria 2025-07-24 10:56:37 -04:00
Ubuntu
7f3e8087ba added support for loading bria transformer 2025-07-24 10:56:37 -04:00
Brandon Rising
dfc7835359 Setup Probe and UI to accept bria main models 2025-07-24 10:56:37 -04:00
psychedelicious
169d58ea4c feat(ui): restore clear queue button
It is accessible in two places:
- The queue actions hamburger menu.
- On the queue tab.

If the clear queue app feature is disabled, it is not shown in either of
those places.
2025-07-23 23:38:53 +10:00
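For illustration, a minimal sketch of that feature gating, assuming a `useFeatureStatus`-style hook; the hook and component names here are hypothetical, not the actual InvokeAI code:

```tsx
import { memo } from 'react';

// Hypothetical names for this sketch - not the actual InvokeAI hook/components.
declare function useFeatureStatus(feature: 'clearQueue'): boolean;
declare function ClearQueueButton(): JSX.Element;

// Rendered in both locations (queue actions menu and queue tab); when the
// clearQueue app feature is disabled, nothing is shown in either place.
export const ClearQueueMenuItem = memo(() => {
  const isEnabled = useFeatureStatus('clearQueue');
  if (!isEnabled) {
    return null;
  }
  return <ClearQueueButton />;
});
ClearQueueMenuItem.displayName = 'ClearQueueMenuItem';
```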
psychedelicious
b53d2250f7 feat(ui): reduce snap tolerance to make it easier to break the snap 2025-07-23 23:05:40 +10:00
psychedelicious
242eea8295 fix(ui): incorrect zoom direction w/ small scroll amounts 2025-07-23 23:05:40 +10:00
psychedelicious
4dabe09e0d tests(ui): remove test for no-longer-valid behaviour 2025-07-23 23:03:02 +10:00
psychedelicious
07fa0d3b77 fix(ui): do not attempt toggle when target panel isn't registered 2025-07-23 23:03:02 +10:00
psychedelicious
e97f82292f tests(ui): add tests for disposable handling 2025-07-23 23:03:02 +10:00
psychedelicious
005bab9035 fix(ui): tab disposables not being added correctly 2025-07-23 23:03:02 +10:00
psychedelicious
409173919c tests(ui): add tests for toggleViewer functionality 2025-07-23 23:03:02 +10:00
psychedelicious
7915180047 feat(ui): restore viewer toggle hotkey 2025-07-23 23:03:02 +10:00
Riccardo Giovanetti
4349b8387d translationBot(ui): update translation (Italian)
Currently translated at 97.9% (2000 of 2042 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-07-23 12:26:48 +10:00
Kent Keirsey
f95b686bdc reposition export button 2025-07-23 11:55:11 +10:00
Mary Hipp
72afb9c3fd fix iterations for all API models 2025-07-22 13:27:35 -04:00
Mary Hipp
f004fc31f1 update whats new 2025-07-22 12:24:10 -04:00
psychedelicious
2aa163b3a2 feat(ui): add default inpaint mask layer on canvas reset 2025-07-22 10:26:57 +10:00
psychedelicious
f40900c173 chore: bump version to v6.1.0 2025-07-22 08:24:31 +10:00
psychedelicious
2c1f2b2873 tidy(ui): move star hotkey into own hook & use reactive state for focus 2025-07-22 08:11:57 +10:00
Kent Keirsey
8418e34480 lint 2025-07-22 08:11:57 +10:00
Kent Keirsey
b548ac0ccf Add Star/Unstar Hotkey and fix hotkey translations 2025-07-22 08:11:57 +10:00
Linos
2af2b8b6c4 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (2003 of 2003 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-07-22 07:58:19 +10:00
Hosted Weblate
058dc06748 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-07-22 07:58:19 +10:00
Riccardo Giovanetti
8acb1c0088 translationBot(ui): update translation (Italian)
Currently translated at 98.7% (1978 of 2003 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1968 of 1994 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-07-22 07:58:19 +10:00
Hosted Weblate
683732a37c translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-07-22 07:58:19 +10:00
Riku
b990eacca0 translationBot(ui): update translation (German)
Currently translated at 62.1% (1251 of 2012 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2025-07-22 07:58:19 +10:00
RyoKoba
5f7e920deb translationBot(ui): update translation (Japanese)
Currently translated at 99.8% (2007 of 2011 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 92.0% (1851 of 2011 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 87.4% (1744 of 1995 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 81.0% (1616 of 1995 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 75.6% (1510 of 1995 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI
2025-07-22 07:58:19 +10:00
Riccardo Giovanetti
55dfdc0a9c translationBot(ui): update translation (Italian)
Currently translated at 97.9% (1953 of 1994 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.7% (1986 of 2011 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.7% (1970 of 1995 strings)

translationBot(ui): update translation (Italian)

Currently translated at 97.8% (1910 of 1952 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-07-22 07:58:19 +10:00
Linos
10d6d19e17 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (2012 of 2012 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 99.7% (2006 of 2012 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 99.5% (2002 of 2012 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 97.8% (1968 of 2012 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 96.4% (1940 of 2012 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1921 of 1921 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1917 of 1917 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-07-22 07:58:19 +10:00
skunkworxdark
15542b954d Fix nodes ui: Make the nodes dot background match the snap-to-grid size and position
Update to Flow.tsx

Changes the size and offset of the dots background to match the snap-to-grid size, and also fixes the background dot pattern alignment.

Currently the snapGrid is 25x25 while the default background dot gap is 20x20, so they do not align. This is fixed by making the gap property of the background the same as the snapGrid.

Additionally, there is a bug in the React Flow background code that incorrectly sets the offset to the centre of the dot pattern when the default offset of 0 is used. To work around this, setting the background offset property to the snapGrid size realigns the dot pattern correctly.

I have logged a bug for the React Flow background issue in its repo:
https://github.com/xyflow/xyflow/issues/5405
2025-07-22 07:46:52 +10:00
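A minimal sketch of that fix, assuming React Flow's `<Background />` component and a 25px snap grid; the package name and component wiring here are illustrative:

```tsx
import { Background, BackgroundVariant, ReactFlow } from '@xyflow/react';

const SNAP_GRID: [number, number] = [25, 25];

// Keep the dot pattern in lockstep with the snap grid so nodes land on dots.
export const Flow = () => (
  <ReactFlow snapToGrid snapGrid={SNAP_GRID}>
    <Background
      variant={BackgroundVariant.Dots}
      gap={25} // default is 20, which does not align with a 25px snap grid
      offset={25} // workaround for xyflow/xyflow#5405: realigns the dot pattern
    />
  </ReactFlow>
);
```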
skunkworxdark
6430d830c1 Update nodes auto layout spacing for snap to grid size
Update workflowSettingsSlice.ts

Change the default settings for auto layout nodeSpacing and layerSpacing to 30 instead of 32. This makes the x position of auto-laid-out nodes land on snap-to-grid positions.

The node width (320) + 30 = 350, which is divisible by the snap-to-grid size of 25.
2025-07-22 07:40:58 +10:00
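The arithmetic behind that choice, spelled out (constant names are illustrative):

```ts
const SNAP_GRID_SIZE = 25;
const NODE_WIDTH = 320;
const NODE_SPACING = 30; // was 32

// 320 + 30 = 350, and 350 % 25 === 0, so each auto-laid column starts on a grid line.
console.log((NODE_WIDTH + NODE_SPACING) % SNAP_GRID_SIZE === 0); // true
// With the old spacing: (320 + 32) % 25 === 2, so columns drift off the grid.
```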
Kent Keirsey
c3f6389291 fix ruff and remove unused API route 2025-07-22 07:33:48 +10:00
Kent Keirsey
070eef3eff remove whitespace 2025-07-22 07:33:48 +10:00
Kent Keirsey
b14d841d57 Extract util and fix model image logic 2025-07-22 07:33:48 +10:00
Kent Keirsey
dd35ab026a update logic and remove bad test 2025-07-22 07:33:48 +10:00
Cursor Agent
7fc06db8ad Add LoRA model metadata extraction from JSON and PNG files
Co-authored-by: kent <kent@invoke.ai>
2025-07-22 07:33:48 +10:00
psychedelicious
9d1f09c0f3 fix(ui): return wrapped history in redux-remember unserialize
We intermittently get an error like this:
```
TypeError: Cannot read properties of undefined (reading 'length')
```

This error is caused by a `redux-undo`-enhanced slice being rehydrated
without the extra stuff it adds to the slice to make it undoable (e.g.
an array of `past` states, the `present` state, array of `future`
states, and some other metadata).

`redux-undo` may need to check the length of the past/future arrays as
part of its internal functionality. These keys don't exist so we get the
error. I'm not sure _why_ they don't exist - my understanding of
`redux-undo` is that it should be checking and wrapping the state w/ the
history stuff automatically. Seems to be related to `redux-remember` -
may be a race condition.

The solution is to ensure we wrap rehydrated state for undoable slices
as we rehydrate them. I discovered the solution while troubleshooting
#8314 when the changes therein somehow triggered the issue to start
occurring every time instead of rarely.
2025-07-22 07:00:57 +10:00
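A minimal sketch of the wrapping step described above, with illustrative names rather than the actual InvokeAI/redux-remember code:

```ts
// Wrap rehydrated state back into redux-undo's history shape
// ({ past, present, future }) before handing it to the store.
type History<S> = { past: S[]; present: S; future: S[] };

const isHistory = (value: unknown): value is History<unknown> =>
  typeof value === 'object' &&
  value !== null &&
  Array.isArray((value as History<unknown>).past) &&
  Array.isArray((value as History<unknown>).future) &&
  'present' in value;

// Called for each persisted undoable slice during rehydration: if the stored
// value is the bare slice state, wrap it; if it is already a history object,
// pass it through unchanged.
export const ensureHistory = <S>(stored: S | History<S>): History<S> =>
  isHistory(stored) ? (stored as History<S>) : { past: [], present: stored as S, future: [] };
```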
skunkworxdark
cacfb183a6 Add auto layout controls to node editor (#8239)
* Add auto layout controls using elkjs to node editor

Introduces auto layout functionality for the node editor using elkjs, including a new UI popover for layout options (placement strategy, layering, spacing, direction). Adds related state and actions to workflowSettingsSlice, updates translations, and ensures elkjs is included in optimized dependencies.

* feat(nodes): Improve workflow auto-layout controls and accuracy

- The auto-layout settings panel is updated to use `Select` dropdowns and `NumberInput`
- The layout algorithm now uses the actual rendered dimensions of nodes from the DOM, falling back to estimates only when necessary. This results in a much more accurate and predictable layout.
- The ELKjs library integration is refactored to fix some warnings

* Update useAutoLayout.ts

prettier

* feat(nodes): Improve workflow auto-layout controls and accuracy

- The auto-layout settings panel is updated to use `Select` dropdowns and `NumberInput`
- The layout algorithm now uses the actual rendered dimensions of nodes from the DOM, falling back to estimates only when necessary. This results in a much more accurate and predictable layout.
- The ELKjs library integration is refactored to fix some warnings

* Update useAutoLayout.ts

prettier

* build(ui): import elkjs directly

* updated to use dagrejs for autolayout

updated to use dagrejs - it has fewer layout options but is already included

but this is still WIP as some nodes don't report the height correctly. I am still investigating this...

* Update useAutoLayout.ts

update to fix layout issues

* minor updates

- pretty useAutoLayout.ts
- add missing type import in ViewportControls.tsx
- update pnpm-lock.yaml with elkjs removed

* Update ViewportControls.tsx

pnpm fix

* Fix Frontend check + single node selection fix

Fix Frontend check - remove unused export from workflowSettingsSlice.ts
Update so that if only a single node is selected, auto layout is applied to all nodes; having a single node selected is common, and this means you don't have to deselect it first.

* feat(ui): misc improvements for autolayout

- Split popover into own component
- Add util functions to get node w/h
- Use magic wand icon for button
- Fix sizing of input components
- Use CompositeNumberInput instead of base chakra number input
- Add zod schemas for string values and use them in the component to
ensure state integrity

* chore(ui): lint

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-07-21 14:44:29 +10:00
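A rough sketch of dagre-based auto layout as described in this PR, assuming `@dagrejs/dagre`; the real `useAutoLayout.ts` additionally reads rendered node dimensions from the DOM and only falls back to estimates:

```ts
import dagre from '@dagrejs/dagre';

type LayoutNode = { id: string; width: number; height: number };
type LayoutEdge = { source: string; target: string };

export const layoutNodes = (nodes: LayoutNode[], edges: LayoutEdge[]) => {
  const g = new dagre.graphlib.Graph();
  // Spacing values mirror the auto-layout settings mentioned above (30px).
  g.setGraph({ rankdir: 'LR', nodesep: 30, ranksep: 30 });
  g.setDefaultEdgeLabel(() => ({}));

  for (const node of nodes) {
    g.setNode(node.id, { width: node.width, height: node.height });
  }
  for (const edge of edges) {
    g.setEdge(edge.source, edge.target);
  }

  dagre.layout(g);

  // dagre reports centre positions; React Flow expects top-left coordinates.
  return nodes.map((node) => {
    const { x, y } = g.node(node.id);
    return { id: node.id, position: { x: x - node.width / 2, y: y - node.height / 2 } };
  });
};
```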
psychedelicious
564f4f7a60 feat(ui): better icon for invert mask button 2025-07-21 13:47:02 +10:00
Kent Keirsey
113a118fcf fix potential for null data 2025-07-21 13:47:02 +10:00
Kent Keirsey
1f930cdaf2 fix 2025-07-21 13:47:02 +10:00
Kent Keirsey
c490e0ce08 feat(ui):invert mask 2025-07-21 13:47:02 +10:00
Kent Keirsey
7640ee307c feat(ui):Adjust-bbox-to-masks 2025-07-21 13:26:49 +10:00
psychedelicious
1f5f70f898 feat(ui): clean up picker compact view default state handling
- Name it `pickerCompactViewStates` because it's not exclusive to the model
picker; it is used for all pickers
- Rename redux action to model an event
- Move selector to right file
- Use selector to derive state for individual picker
2025-07-21 13:18:09 +10:00
Mary Hipp
1430858112 cleanup 2025-07-21 13:18:09 +10:00
Mary Hipp
48c27ec117 persist model picker compact/expanded state 2025-07-21 13:18:09 +10:00
psychedelicious
af7737e804 fix(ui): context menu on staging area images
There was a subtle issue where the progress image wasn't ever cleared,
preventing the context menu from working on staging area preview images.

The staging area preview images were displaying the last progress image
_on top of_ the result image. Because the image elements were so small,
you wouldn't notice that you were looking at a low-res progress image.
Right clicking a progress image gets you no menu.

If you refresh the page or switch tabs, this would fix itself, because
those actions clear out the progress images. The result image would then
be the topmost element, and the context menu works.

Fixing this without introducing a flash of empty space as the progress
image was hidden required a bit of refactoring. We have to wait for the
result image element to load before clearing out the progress.

Result - progress images appear to "resolve" to result images in the
staging area without any blips or jank, and the context menu works after
that happens.
2025-07-21 13:15:34 +10:00
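A minimal sketch of that "resolve" behaviour; the prop and component names are illustrative:

```tsx
import { memo, useCallback, useState } from 'react';

type Props = { resultUrl: string; progressUrl: string | null };

// Keep rendering the last progress image until the result <img> has actually
// loaded, then drop it so the result becomes the topmost (right-clickable)
// element without a flash of empty space.
export const StagingPreview = memo(({ resultUrl, progressUrl }: Props) => {
  const [isResultLoaded, setIsResultLoaded] = useState(false);
  const onLoad = useCallback(() => setIsResultLoaded(true), []);

  return (
    <div style={{ position: 'relative' }}>
      <img src={resultUrl} onLoad={onLoad} alt="result" />
      {!isResultLoaded && progressUrl && (
        <img src={progressUrl} alt="progress" style={{ position: 'absolute', inset: 0 }} />
      )}
    </div>
  );
});
StagingPreview.displayName = 'StagingPreview';
```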
psychedelicious
3eca0d2ba0 fix(ui): staging area left/right hotkeys 2025-07-18 08:08:15 -04:00
psychedelicious
307259f096 fix(ui): ensure staging area always has the right state and session association 2025-07-18 08:08:15 -04:00
psychedelicious
bed01941a5 fix(ui): ensure we clean up when session id changes 2025-07-18 08:08:15 -04:00
psychedelicious
89fa43a3b6 docs(ui): update StagingAreaApi docstrings 2025-07-18 08:08:15 -04:00
psychedelicious
d8fcb08abf repo: update ignores 2025-07-18 08:08:15 -04:00
psychedelicious
c61bcd9f50 tests(ui): add test suite for StagingAreaApi 2025-07-18 08:08:15 -04:00
psychedelicious
3fb0fcbbfb tidy(ui): move staging area components to correct dir 2025-07-18 08:08:15 -04:00
psychedelicious
db9af5083f tidy(ui): move launchpad components to ui dir 2025-07-18 08:08:15 -04:00
psychedelicious
720f1bb65c chore(ui): rename context2.tsx -> context.tsx 2025-07-18 08:08:15 -04:00
psychedelicious
7dfb318ba2 chore(ui): lint 2025-07-18 08:08:15 -04:00
psychedelicious
9b024da2b4 refactor(ui): move staging area logic outside react
Was running into difficulties reasoning about the logic and couldn't
write tests because it was all in react.

Moved logic outside react, updated context, made it testable.
2025-07-18 08:08:15 -04:00
psychedelicious
15ca3b727a wip 2025-07-18 08:08:15 -04:00
psychedelicious
74ca604ae0 fix(ui): unstyled error boundary 2025-07-18 08:08:15 -04:00
psychedelicious
6934b05c85 fix(ui): use invocation context provider in inspector panel 2025-07-18 08:08:15 -04:00
psychedelicious
1a47a5317c chore(ui): update dockview to latest
Remove extraneous fix now that the disableDnd issue is resolved upstream
2025-07-18 08:08:15 -04:00
psychedelicious
bc3ef21c64 chore(ui): bump version to v6.1.0rc2 2025-07-18 08:08:15 -04:00
psychedelicious
e329f5ad43 fix(ui): negative style prompt not recorded in metadata 2025-07-18 06:41:21 +10:00
psychedelicious
e6ad91bf89 chore(ui): update prettier config 2025-07-17 22:04:57 +10:00
psychedelicious
2f586416a5 chore(ui): remove unused pkgs 2025-07-17 22:04:57 +10:00
psychedelicious
33b56f421c chore(ui): lint 2025-07-17 22:04:57 +10:00
psychedelicious
e58ee4c492 chore(ui): upgrade zod 2025-07-17 22:04:57 +10:00
psychedelicious
49691aa07e chore(ui): upgrade rollup vis 2025-07-17 22:04:57 +10:00
psychedelicious
56570f235f chore(ui): actually upgrade storybook 2025-07-17 22:04:57 +10:00
psychedelicious
a2d95cf5b6 chore(ui): upgrade minor bump packages 2025-07-17 22:04:57 +10:00
psychedelicious
704dbfd04a chore(ui): upgrade storybook 2025-07-17 22:04:57 +10:00
psychedelicious
5d9e078043 chore(ui): finish eslint v9 migration 2025-07-17 22:04:57 +10:00
psychedelicious
875cde13ae chore(ui): migrate to eslint v9 (wip) 2025-07-17 22:04:57 +10:00
psychedelicious
77655aed86 chore(ui): update eslint config 2025-07-17 22:04:57 +10:00
psychedelicious
0628b92d63 chore: bump version to v6.1.0rc1 2025-07-17 19:30:38 +10:00
psychedelicious
9e526d00c2 chore(ui): lint 2025-07-17 15:36:24 +10:00
psychedelicious
1a24396be8 feat(ui): styling when nodes have error 2025-07-17 15:36:24 +10:00
psychedelicious
d97e73a565 chore(ui): lint 2025-07-17 15:36:24 +10:00
psychedelicious
55b14c8aaf perf(ui): optimize redux selectors for workflow editor
- Build selectors for each node in a react context so components can
re-use the same selectors
- Cache the selectors in the context
2025-07-17 15:36:24 +10:00
psychedelicious
79f65e57eb fix(ui): remove unnecessary coalescing operator 2025-07-17 14:21:02 +10:00
Kent Keirsey
b4c8950278 address comments 2025-07-17 14:21:02 +10:00
Kent Keirsey
400b2e9a55 unlint. 2025-07-17 14:21:02 +10:00
Kent Keirsey
3a687c583a lint 2025-07-17 14:21:02 +10:00
Kent Keirsey
833950078d commit tile size controls 2025-07-17 14:21:02 +10:00
Kent Keirsey
e698dcb148 unlint. 2025-07-17 14:21:02 +10:00
Kent Keirsey
218386e077 lint 2025-07-17 14:21:02 +10:00
Kent Keirsey
4426be9e64 commit tile size controls 2025-07-17 14:21:02 +10:00
psychedelicious
86f4cf7857 feat(ui): related embedding styling/tidy 2025-07-17 14:12:29 +10:00
Kent Keirsey
49ae66d94a Added related model support 2025-07-17 14:12:29 +10:00
Cursor Agent
c10865c7ef Reorder embedding options in PromptTriggerSelect component
Co-authored-by: kent <kent@invoke.ai>
2025-07-17 14:12:29 +10:00
psychedelicious
f3478a189a fix(ui): able to drag empty space in tab bar and detach panels 2025-07-17 13:58:32 +10:00
psychedelicious
43db29176a chore(ui): lint 2025-07-17 13:52:24 +10:00
psychedelicious
f38922929c docs(ui): comments in modelsLoaded 2025-07-17 13:52:24 +10:00
psychedelicious
7d02c58f86 fix(ui): move <ParamTileControlNetModel /> to <UpscaleTabAdvancedSettingsAccordion /> 2025-07-17 13:52:24 +10:00
Kent Keirsey
6edce8be87 Add scaling in 2025-07-17 13:52:24 +10:00
Kent Keirsey
31f63e38bd lint 2025-07-17 13:52:24 +10:00
Kent Keirsey
78a68ac3a7 Updated 2025-07-17 13:52:24 +10:00
Kent Keirsey
8cd3bcd1c0 Updates 2025-07-17 13:52:24 +10:00
Cursor Agent
264cc5ef46 Add tile ControlNet model selection to upscale settings
Co-authored-by: kent <kent@invoke.ai>
2025-07-17 13:52:24 +10:00
JPPhoto
8bfbea5ed3 Updated __init__.py 2025-07-17 06:33:56 +10:00
JPPhoto
f06a66da07 Updated schema.ts 2025-07-17 06:33:56 +10:00
Jonathan
337cae9b22 Update __init__.py
Added FluxConditioningField, FluxConditioningCollectionOutput, and FluxConditioningCollectionOutput,
2025-07-17 06:33:56 +10:00
Jonathan
bf926bb7d5 Update primitives.py
Added FluxConditioningCollectionOutput
2025-07-17 06:33:56 +10:00
psychedelicious
18ad9a6af3 feat(ui): canvas/viewer panel tabs show progress 2025-07-17 06:20:05 +10:00
psychedelicious
b6ed31c222 feat(ui): clicking invoke switches to viewer tab instead of canvas when save all images to gallery is enabled 2025-07-17 06:20:05 +10:00
psychedelicious
200beb5af5 feat(ui): make save all images to gallery option also bypass canvas 2025-07-17 06:20:05 +10:00
psychedelicious
f82a948bdd refactor(ui): canvas autoswitch logic
Simplify the canvas auto-switch logic to not rely on the preview images
loading. This fixes an issue where offscreen preview images didn't get
auto-switched to. Images are now loaded directly.
2025-07-17 06:20:05 +10:00
psychedelicious
dd03e3ddcd refactor(ui): simplify canvas session logic 2025-07-17 06:20:05 +10:00
psychedelicious
7561b73e8f fix(ui): uppercase file extensions blocked for image upload
Closes #8284
2025-07-17 00:48:36 +10:00
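A minimal sketch of a case-insensitive extension check for uploads; the accepted list and function name are illustrative:

```ts
// Accept "PHOTO.PNG" the same as "photo.png".
const ACCEPTED_EXTENSIONS = ['png', 'jpg', 'jpeg', 'webp'];

export const isAllowedImageFile = (fileName: string): boolean => {
  const ext = fileName.split('.').pop()?.toLowerCase() ?? '';
  return ACCEPTED_EXTENSIONS.includes(ext);
};
```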
psychedelicious
caa97608c7 fix(ui): aspect ratios out of order 2025-07-16 23:27:37 +10:00
Mary Hipp
72a6d1edc1 simplify description styling 2025-07-16 09:19:33 -04:00
Mary Hipp
b8bf89c2f1 add fallback image and make sure description text is legible for model picker noncompact 2025-07-16 09:19:33 -04:00
psychedelicious
a1ade2b8c0 feat(ui): export apis & actions from package 2025-07-16 08:21:03 -04:00
Eugene Brodsky
4bdcae1f8f fix(docker): switch to pnpm10.x 2025-07-15 13:03:15 -04:00
Jonathan
4b22c84407 Update dev-environment.md
Document the latest changes required to build Invoke 6.0.
2025-07-15 15:21:01 +10:00
Eugene Brodsky
c9daf1db30 (fix) remove timeout from image prompt expansion (#8281) 2025-07-14 11:19:20 -04:00
psychedelicious
06d3cfbe97 gh: update bug report template
- Add require drop down for install method
- Make browser version optional
- Link to latest release
- Update verbiage for sys info section
2025-07-14 12:18:52 +10:00
psychedelicious
71e4901313 fix(ui): ignore disabled ref images in readiness checks 2025-07-14 10:51:51 +10:00
psychedelicious
82fb897b62 chore(ui): lint 2025-07-12 14:56:57 +10:00
psychedelicious
192b00d969 chore: bump version to v6.0.2 2025-07-12 14:56:57 +10:00
psychedelicious
7bb25ef1b4 fix(ui): gallery dnd 2025-07-12 14:56:57 +10:00
psychedelicious
62f52c74a8 fix(ui): linked negative style prompt not passed in
Closes #8256
2025-07-12 10:22:17 +10:00
psychedelicious
97439c1daa fix(ui): native context menu shown on right click on short fat images
Closes #8254
2025-07-12 10:22:17 +10:00
psychedelicious
b23bff1b53 fix(ui): center staging area images 2025-07-12 10:22:17 +10:00
psychedelicious
d9a1efbabf fix(ui): staging area images may be slightly too large 2025-07-12 10:22:17 +10:00
psychedelicious
d4e903ee2d chore: bump version to v6.0.1 2025-07-12 10:22:17 +10:00
Kevin Turner
bb3e5d16d8 feat(Model Manager): refuse to download a file when there's insufficient space 2025-07-12 10:14:25 +10:00
psychedelicious
e62d3f01a8 feat(app): better error message for failed model probe
- Old: No valid config found
- New: Unable to determine model type
2025-07-11 23:35:43 +10:00
psychedelicious
757ecdbf82 build(ui): downgrade idb-keyval
We have increased error rates after updating this package. Let's try
downgrading to see if that fixes the issue.
2025-07-11 15:00:10 +10:00
psychedelicious
694c85b041 fix(ui): language file filenames
Need to replace the underscores w/ dashes - this was missed in #8246.
2025-07-11 14:21:41 +10:00
psychedelicious
988d7ba24c chore: bump version to v6.0.1rc1 2025-07-11 09:05:24 +10:00
psychedelicious
ac981879ef fix(ui): runtime errors related to calling reduce on array iterator
Fix an issue in certain browsers/builds causing a runtime error.

A zod enum has a .options property, which is an array of all the options
for the enum. This is handy for when you need to derive something from a
zod schema.

In this case, we represented the possible focus regions in the zod enum,
then derived a mapping of region names to set of target HTML elements.
Why isn't important, but suffice to say, we were using the .options
property for this.

But actually, we were using .options.values(), then calling .reduce() on
that. An array's .values() method returns an _array iterator_. Array
iterators do not have .reduce() methods!

Except, apparently in some environments they do - it depends on the JS
engine and whether or not polyfills for iterator helpers were included
in the build.

Turns out my dev environment - and most user browsers - do provide
.reduce(), so we didn't catch this error. It took a large deployment and
error monitoring to catch it.

I've refactored the code to totally avoid deriving data from zod in this
way.
2025-07-11 08:25:47 +10:00
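A small sketch of the portable pattern, using an illustrative zod enum rather than the real focus-region schema:

```ts
import { z } from 'zod';

// zFocusRegion.options is a plain array; calling .values() on it returns an
// array *iterator*, which only has .reduce() when the environment ships
// iterator-helper support or polyfills. Reducing over the array itself is
// portable everywhere.
const zFocusRegion = z.enum(['canvas', 'gallery', 'workflows']);
type FocusRegion = z.infer<typeof zFocusRegion>;

const regionTargets = zFocusRegion.options.reduce(
  (acc, region) => {
    acc[region] = new Set<HTMLElement>();
    return acc;
  },
  {} as Record<FocusRegion, Set<HTMLElement>>
);
```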
psychedelicious
fc71849c24 feat(app): expose a cursor, not a connection in db util 2025-07-11 08:20:06 +10:00
psychedelicious
a19aa3b032 feat(app): db abstraction to prevent threading conflicts
- Add a context manager to the SqliteDatabase class which abstracts away
creating a transaction, committing it on success and rolling back on
error.
- Use it everywhere. The context manager should be exited before
returning results. No business logic changes should be present.
2025-07-11 08:20:06 +10:00
psychedelicious
ef4d5d7377 feat(ui): virtualized list for staging area
Make the staging area a virtualized list so it doesn't choke when there
are a large number (i.e. more than a few hundred) of queue items.
2025-07-11 07:50:57 +10:00
Mary Hipp Rogers
6b0dfd8427 dont reset canvas if studio is loaded with canvas destination (#8252)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-07-10 09:36:41 -04:00
psychedelicious
471c010217 fix(ui): invalid language crashes app
- Apparently locales must use hyphens instead of underscores. This must
have been a fairly recent change that we didn't catch. It caused i18n to
throw for Brazilian Portuguese and both Simplified and Traditional
Mandarin. Change the locales to use the right strings.
- Move the theme + locale provider inside of the error boundary. This
allows errors with locales to be caught by the error boundary instead of
hard-crashing the app. The error screen is unstyled if this happens but
at least it has the reset button.
- Add a migration for the system slice to fix existing users' language
selections. For example, if the user had an incorrect language setting
of `zh_CN`, it will be changed to the correct `zh-CN`.
2025-07-10 14:27:36 +10:00
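A minimal sketch of the locale migration described in the last bullet; the state shape and function name are illustrative:

```ts
// Rewrite persisted locales that use underscores (e.g. "zh_CN") to the
// hyphenated form i18next expects ("zh-CN").
type SystemState = { language: string };

export const migrateSystemState = (state: SystemState): SystemState => ({
  ...state,
  language: state.language.replace('_', '-'),
});

// migrateSystemState({ language: 'zh_CN' }) -> { language: 'zh-CN' }
```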
psychedelicious
b1193022f7 fix(ui): sometimes images added to gallery show as placeholder only
The range-based fetching logic had a subtle bug - it didn't keep track
of what the _current_ visible range is - only the ranges that the user
last scrolled to.

When an image was added to the gallery, the logic saw that the images
had changed, but thought it had already loaded everything it needed to,
so it didn't load the new image.

The updated logic tracks the current visible range separately from the
accumulated scroll ranges to address this issue.
2025-07-10 14:27:36 +10:00
psychedelicious
2152ca092c fix(ui): workaround for dockview bug that lets you drag tabs in certain ways 2025-07-10 14:27:36 +10:00
psychedelicious
ccc62ba56d perf(ui): revised range-based fetching strategy
When the user scrolls in the gallery, we are alerted of the new range of
visible images. Then we fetch those specific images.

Previously, each change of range triggered a throttled function to fetch
that range. The throttle timeout was 100ms.

Now, each change of range appends that range to a list of ranges and
triggers the throttled fetch. The timeout is increased to 500ms, but to
compensate, each fetch handles all ranges that had been accumulated
since the last fetch.

The result is far fewer network requests, but each of them gets more
images.
2025-07-10 14:27:36 +10:00
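A minimal sketch of the accumulate-then-flush strategy, which also tracks the current visible range separately (per the fix two commits up); the names and fetch signature are illustrative:

```ts
type Range = { startIndex: number; endIndex: number };

const pendingRanges: Range[] = [];
let currentRange: Range = { startIndex: 0, endIndex: 0 };
let flushTimer: ReturnType<typeof setTimeout> | null = null;

// Hypothetical fetcher for this sketch.
declare function fetchImagesForRanges(ranges: Range[]): Promise<void>;

const flush = () => {
  flushTimer = null;
  // Include the *current* visible range too, so newly added images that fall
  // inside it are fetched even if the user hasn't scrolled since.
  const ranges = [...pendingRanges, currentRange];
  pendingRanges.length = 0;
  void fetchImagesForRanges(ranges);
};

export const onRangeChanged = (range: Range) => {
  currentRange = range;
  pendingRanges.push(range);
  if (flushTimer === null) {
    flushTimer = setTimeout(flush, 500); // 500ms window, one request per flush
  }
};
```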
psychedelicious
9cf82de8c5 fix(ui): check for absolute value of scroll velocity to handle scrolling up 2025-07-10 14:27:36 +10:00
psychedelicious
aced349152 perf(ui): increase viewport in gallery
This allows us to prefetch more images and reduce how often placeholders
are shown as we fetch more images in the gallery.
2025-07-10 14:27:36 +10:00
psychedelicious
0d67ee6548 tests(ui): fix logging mock 2025-07-09 23:15:25 +10:00
psychedelicious
03c21d1607 fix(ui): gallery not updating when saving staging area image 2025-07-09 23:15:25 +10:00
psychedelicious
752e8db1f5 tidy(ui): demote logging in nav api to trace 2025-07-09 23:15:25 +10:00
psychedelicious
85fc861dd9 chore(ui): lint 2025-07-09 23:15:25 +10:00
psychedelicious
458cbfd874 fix(ui): selected model not highlighted 2025-07-09 23:15:25 +10:00
psychedelicious
04331c070a fix(ui): set denoise w/h when running flux fill 2025-07-09 23:15:25 +10:00
psychedelicious
632ddf0cb4 tests(ui): update tests for navigation api 2025-07-09 23:15:25 +10:00
psychedelicious
2b193ff416 fix(ui): delete stored state on error & save new state 2025-07-09 23:15:25 +10:00
psychedelicious
96ee394f9e refactor(ui): use dockview's own ser/de for persistence 2025-07-09 23:15:25 +10:00
psychedelicious
0badc80c0c fix(ui): ignore disabled ref images in readiness checks 2025-07-09 23:15:25 +10:00
psychedelicious
78e6cbf96e fix(ui): default tab is generate 2025-07-09 23:15:25 +10:00
psychedelicious
0b969a661b fix(ui): remove dep on focus from useDeleteImage 2025-07-09 23:15:25 +10:00
psychedelicious
6fe47ec9f8 feat(ui): improve ref image model autoswitch logic 2025-07-09 23:15:25 +10:00
Kent Keirsey
3850dd61f8 update comment 2025-07-09 23:15:25 +10:00
Kent Keirsey
75520eaf0f Match Chatgpt4o and kontext names exactly 2025-07-09 23:15:25 +10:00
Kent Keirsey
10e88c58c1 fix and lint 2025-07-09 23:15:25 +10:00
Kent Keirsey
30ed4dbd92 lint 2025-07-09 23:15:25 +10:00
Kent Keirsey
ed9c090f33 fixes 2025-07-09 23:15:25 +10:00
Kent Keirsey
d29f65ed22 lint fixes 2025-07-09 23:15:25 +10:00
Kent Keirsey
2062ec8ac0 Update invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
Co-authored-by: Mary Hipp Rogers <maryhipp@gmail.com>
2025-07-09 23:15:25 +10:00
Cursor Agent
49e818338a Changes from background composer bc-abfadb27-a265-41a7-b0db-829879f4701e 2025-07-09 23:15:25 +10:00
Cursor Agent
1caab2b9c4 Implement automatic reference image model switching on base model change
Co-authored-by: kent <kent@invoke.ai>
2025-07-09 23:15:25 +10:00
psychedelicious
50079ea349 fix(ui): big red cancel button has diff behaviour than staging discard 2025-07-09 23:15:25 +10:00
psychedelicious
fffa1b24c4 fix(ui): isStaging selector could return wrong query cache 2025-07-09 23:15:25 +10:00
psychedelicious
a6d6170387 fix(ui): discarding 1 item when 2 items left in staging area discards both 2025-07-09 23:15:25 +10:00
psychedelicious
e5fceb0448 fix(ui): whole app scrolls while selecting staging area image 2025-07-09 23:15:25 +10:00
psychedelicious
059baf5b29 chore(ui): lint 2025-07-09 23:15:25 +10:00
psychedelicious
1be8a9a310 fix(ui): add metadata i18nKey to handler; fixes metadata toasts 2025-07-09 23:15:25 +10:00
psychedelicious
7adc33e04d refactor(ui): metadata recall buttons & hotkeys (WIP) 2025-07-09 23:15:25 +10:00
psychedelicious
7f2dd22d47 refactor(ui): metadata recall buttons & hotkeys (WIP) 2025-07-09 23:15:25 +10:00
psychedelicious
bb50f4b8a2 fix(ui): prevent panels from growing on init
This works but I think a better solution is to use dockview's provided
serialization API to store and restore layouts.
2025-07-09 23:15:25 +10:00
psychedelicious
a48958e0d4 chore(ui): lint 2025-07-09 23:15:25 +10:00
psychedelicious
e3a1e9af53 feat(ui): staging area updates
- Smaller staged image previews.
- Move autoswitch buttons to staging area toolbar, remove from settings
popover and the little three-dots menu. Use persisted autoswitch
setting, which is renamed from `defaultAutoSwitch` to
`stagingAreaAutoSwitch`.
- Fix issue with misaligned border radii in staging area preview images.
Required small changes to DndImage and its usage elsewhere.
- Fix issue where staging area toolbar could show up without any
previews in the list.
- Migrate canvas settings slice to use zod schema and inferred types for
its state.
2025-07-09 23:15:25 +10:00
psychedelicious
c6fe11c42f fix(ui): disable gallery hotkeys when in staging area 2025-07-09 23:15:25 +10:00
psychedelicious
4eb1bd67df fix(ui): hide staging area when there are no items 2025-07-09 23:15:25 +10:00
psychedelicious
c376f914d2 chore: bump version v6.0.0 2025-07-09 23:15:25 +10:00
Kent Keirsey
b5d1c47ef7 final link fix 2025-07-09 10:17:38 +10:00
Kent Keirsey
004a52ca65 fix to direct links 2025-07-09 10:17:38 +10:00
Kent Keirsey
b1d5a51ddf add-quantized-kontext-dev 2025-07-09 10:17:38 +10:00
Kent Keirsey
2b2498eaa1 fix prettier quirk 2025-07-08 14:54:29 -04:00
Kent Keirsey
10dda4440e Fix label 2025-07-08 14:54:29 -04:00
Cursor Agent
98f78abefa Add default auto-switch mode setting for canvas sessions
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 14:54:29 -04:00
Mary Hipp Rogers
cc93fa270f update whats new for v6 (#8234)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 18:24:33 +00:00
Mary Hipp Rogers
014b27680f fix flux kontext error (#8235)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 13:42:48 -04:00
Mary Hipp Rogers
c3d8f875de if on generate tab, recall dimensions instead of bbox (#8233)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 13:09:21 -04:00
Mary Hipp Rogers
79f9dc6e4a fix(ui): dont show option to add new layer from if on generate tab (#8231)
* dont show option to add new layer from if on generate tab

* only disable width/height recall if staging AND canvas tab

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 11:46:54 -04:00
psychedelicious
6e1c0c1105 chore: bump version to v6.0.0rc5 2025-07-08 11:26:47 -04:00
Mary Hipp Rogers
0362524040 remove hard-coded flux kontext dev guidance (#8230)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 10:26:20 -04:00
psychedelicious
dc6656459b docs(ui): updated comments for navigation api 2025-07-08 07:30:36 -04:00
psychedelicious
3ea1b97f6f fix(ui): protect against getting stuck on tab loading screen 2025-07-08 07:30:36 -04:00
psychedelicious
a7c7405ccc feat(ui): style model picker selected item 2025-07-08 07:28:07 -04:00
psychedelicious
c391f1117a fix(ui): traverse groups when finding selected model in picker 2025-07-08 07:28:07 -04:00
psychedelicious
b1e2cb8401 fix(ui): queue tab list of queue items
Reverted incomplete change to how queue items are listed. In the future
I think we should redo it to work like the gallery. For now, it is back
the way it was in v5.
2025-07-08 07:22:51 -04:00
Emmanuel Ferdman
db6af134b7 fix: resolve FastAPI deprecation warning for example fields
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-07-08 20:54:08 +10:00
Emmanuel Ferdman
7e6cffb00c fix: resolve FastAPI deprecation warning for example fields
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-07-08 20:54:08 +10:00
psychedelicious
5b187bcb00 fix(ui): pull bbox into ref image component 2025-07-08 14:54:43 +10:00
psychedelicious
0843d609a3 feat(ui): add list of warnings in tooltip on ref image 2025-07-08 14:54:43 +10:00
Kent Keirsey
95bd9cef18 Lint 2025-07-08 14:54:43 +10:00
Kent Keirsey
931d6521f6 Adds bbox to ref image button 2025-07-08 14:54:43 +10:00
psychedelicious
e37665ff59 tests(ui): add wiggle room to timeout tests 2025-07-08 12:55:33 +10:00
psychedelicious
56857fbbe6 tests(ui): add tests for panel storage 2025-07-08 12:55:33 +10:00
psychedelicious
43cfb8a574 tests(ui): get tests passing
Still need tests for panel storage.
2025-07-08 12:55:33 +10:00
psychedelicious
05b1682d15 fix(ui): handle collapsed panels when rehydrating their state 2025-07-08 12:55:33 +10:00
psychedelicious
69a08ee7f2 feat(ui): panel state persistence (WIP) 2025-07-08 12:55:33 +10:00
psychedelicious
18212c7d8a feat(ui): clean up navigation API surface and add comments 2025-07-08 12:55:33 +10:00
psychedelicious
7de26f8e69 feat(ui): clean up auto layout context for panels 2025-07-08 12:55:33 +10:00
Kent Keirsey
0652b12a6f Address comments 2025-07-08 12:31:11 +10:00
Kent Keirsey
43a361a00f prettier 2025-07-08 12:31:11 +10:00
Kent Keirsey
cf68ad9cbc update links to playlist instead of video 2025-07-08 12:31:11 +10:00
Kent Keirsey
ec02a39325 fixes 2025-07-08 12:31:11 +10:00
Kent Keirsey
e52d7a05c2 Update support links. 2025-07-08 12:31:11 +10:00
Cursor Agent
c9d4e2b761 Refactor support videos modal to simplify video and playlist handling
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 12:31:11 +10:00
Kent Keirsey
ac26aa9508 fix 2025-07-08 12:31:11 +10:00
Cursor Agent
9ff6ada15b Add support for video playlists in support videos modal
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 12:31:11 +10:00
psychedelicious
e81a115169 chore(ui): lint 2025-07-08 12:23:57 +10:00
Kent Keirsey
52827807de remove ref image from upscale 2025-07-08 12:23:57 +10:00
Kent Keirsey
b631de4cb5 consistency 2025-07-08 12:20:08 +10:00
Kent Keirsey
099ebdbc37 fix 2025-07-08 12:20:08 +10:00
psychedelicious
4de6549be9 refactor(ui): track discarded items instead of using delete method 2025-07-08 12:12:55 +10:00
psychedelicious
368be34949 chore(ui): lint 2025-07-08 12:12:55 +10:00
psychedelicious
5baa4bd916 refactor(ui): use cancelation for staging area (mostly) 2025-07-08 12:12:55 +10:00
psychedelicious
4229377532 fix(app): ensure cancel events are emitted for current item when bulk canceling
There was a bug where bulk cancel operations would cancel the current
queue item in the DB but not emit the status changed events correctly.
2025-07-08 12:12:55 +10:00
psychedelicious
2610772ffd feat(ui): tighten up launchpad content to fit better 2025-07-08 08:57:44 +10:00
psychedelicious
193de6a8f2 feat(ui): add launchpad container component 2025-07-08 08:57:44 +10:00
psychedelicious
7ea343c787 tidy(ui): remove "staging" from the new settings verbiage 2025-07-08 07:10:55 +10:00
Kent Keirsey
12179dabba fix prettier 2025-07-08 07:10:55 +10:00
Cursor Agent
ef135f9923 Add option to save all staging images to gallery in canvas mode
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 07:10:55 +10:00
Mary Hipp
e6c67cc00f update toast for prompt expansion failed 2025-07-08 06:42:00 +10:00
psychedelicious
179b988148 fix(ui): prompt concat derived state recall 2025-07-08 06:37:43 +10:00
psychedelicious
d913a3c85b fix(ui): reset selected ref image when replacing all
Fixes an unhandled error in a selector that can throw.
2025-07-08 06:37:43 +10:00
psychedelicious
e79525c40c docs(ui): update comments 2025-07-08 06:11:32 +10:00
psychedelicious
f409f913ac fix(ui): navigation api usage 2025-07-08 06:11:32 +10:00
Mary Hipp
7a79f61d4c add claude nodes to blacklist for publishing 2025-07-08 05:50:40 +10:00
psychedelicious
ea182c234b chore: bump version to v6.0.0rc4 2025-07-07 22:15:28 +10:00
psychedelicious
f2eee4a82d chore(ui): lint 2025-07-07 22:05:49 +10:00
psychedelicious
e129525306 fix(app): handle None in queue count queries 2025-07-07 22:05:49 +10:00
psychedelicious
ecedfce758 feat(ui): support a min expanded size for collapsible panels 2025-07-07 22:05:49 +10:00
psychedelicious
702cb2cb1e fix(ui): flux kontext special handling for ref image models 2025-07-07 22:05:49 +10:00
psychedelicious
2e8db3cce3 fix(ui): ensure noise is correctly sized 2025-07-07 22:05:49 +10:00
psychedelicious
7845623fa5 fix(ui): session context indexing bug 2025-07-07 22:05:49 +10:00
psychedelicious
e6a25ca7a2 feat(ui): render progress as indeterminate when percentage is 0
When percentage is zero, the progress bar looks the same as it does when
no generation is in progress. Render it as indeterminate (pulsing) when
percentage is zero to indicate that something is happening.
2025-07-07 22:05:49 +10:00
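A minimal sketch of that behaviour, assuming a Chakra-style `Progress` component; the prop names are illustrative:

```tsx
import { Progress } from '@chakra-ui/react';

type Props = { percentage: number; isGenerating: boolean };

// A 0% bar is indistinguishable from "idle", so show the indeterminate
// (pulsing) state until real progress arrives.
export const GenerationProgressBar = ({ percentage, isGenerating }: Props) => (
  <Progress isIndeterminate={isGenerating && percentage === 0} value={percentage * 100} />
);
```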
psychedelicious
71e12bcebe fix(ui): when no negative prompt is provided, recall it as null 2025-07-07 22:05:49 +10:00
psychedelicious
863c7eb9e2 fix(ui): metadata display for primitive values 2025-07-07 22:05:49 +10:00
psychedelicious
9945c20d02 refactor(ui): simplify graph builders (WIP) 2025-07-07 22:05:49 +10:00
psychedelicious
e3c1334b1f refactor(ui): simplify graph builders (WIP) 2025-07-07 22:05:49 +10:00
psychedelicious
c143f63ef0 refactor(ui): simplify graph builders (WIP) 2025-07-07 22:05:49 +10:00
psychedelicious
067026a0d0 feat(ui): add autocomplete for Graph.addEdgeToMetadata 2025-07-07 22:05:49 +10:00
psychedelicious
66991334fc refactor(ui): simplify graph builder handling of VAE encode and seed 2025-07-07 22:05:49 +10:00
psychedelicious
b771c3b164 refactor(ui): update graphs to use the right w/h/aspect 2025-07-07 22:05:49 +10:00
psychedelicious
4925694dc1 feat(ui): generate tab has separate w/h/aspect 2025-07-07 22:05:49 +10:00
psychedelicious
0a737ced44 feat(ui): add dimensions to params slice 2025-07-07 22:05:49 +10:00
psychedelicious
8d83caaae0 feat(ui): extract aspect ratios from canvas reducers 2025-07-07 22:05:49 +10:00
psychedelicious
16c8017f1a feat(ui): more resilient gallery scrollIntoView 2025-07-07 22:05:49 +10:00
psychedelicious
61a35f1396 fix(ui): skip optimistic updates for gallery when using search term 2025-07-07 22:05:49 +10:00
psychedelicious
6bd004d868 fix(ui): clear ref images when recalling all
Closes #8202
2025-07-07 22:05:49 +10:00
psychedelicious
b6a6d406c7 chore(ui): typegen 2025-07-07 10:25:24 +10:00
psychedelicious
8e287c32ee chore(ui): lint 2025-07-07 10:25:24 +10:00
psychedelicious
2d8b5e26c2 build(ui): bump vite to latest 2025-07-07 10:25:24 +10:00
psychedelicious
50914b74ee chore(build): update pnpm to v10 2025-07-07 10:25:24 +10:00
psychedelicious
0fc1c33536 chore(ui): knip 2025-07-07 10:25:24 +10:00
psychedelicious
3b08c35f72 chore(ui): update knip config 2025-07-07 10:25:24 +10:00
psychedelicious
607b2561fd chore(ui): bump knip to latest 2025-07-07 10:25:24 +10:00
psychedelicious
d68f922efb fix(ui): restore upscale-tab-specific settings components 2025-07-07 10:25:24 +10:00
psychedelicious
2bbd74d418 feat(ui): restore canvas busy spinner 2025-07-07 10:25:24 +10:00
psychedelicious
3a5392a9ee chore: bump version to v6.0.0rc3 2025-07-04 20:46:08 +10:00
psychedelicious
6f80efe71d fix(ui): bump expandprompt timeout to 15s 2025-07-04 20:46:08 +10:00
psychedelicious
7fac833813 fix(ui): ref image model types again 2025-07-04 20:35:29 +10:00
psychedelicious
b67eb4134d fix(ui): select next image when deleting 2025-07-04 20:35:29 +10:00
psychedelicious
522eeda2e2 fix(ui): ref image model types 2025-07-04 20:35:29 +10:00
psychedelicious
76233241f0 fix(ui): include ref image metadata for flux kontext 2025-07-04 20:35:29 +10:00
psychedelicious
54be9989c5 feat(ui): add 'replace' and 'merge' strategies for upsertMetadata 2025-07-04 20:35:29 +10:00
psychedelicious
0d3af08d27 fix(ui): prompt parsing in useImageActions 2025-07-04 20:35:29 +10:00
psychedelicious
767ac91f2c fix(nodes): revert unnecessary version bump 2025-07-04 20:35:29 +10:00
psychedelicious
68571ece8f tidy(app): remove unused methods 2025-07-04 20:35:29 +10:00
psychedelicious
01100a2b9a fix(ui): check for ref image config compatibility for flux kontext dev 2025-07-04 20:35:29 +10:00
psychedelicious
ce2e6d8ab6 fix(ui): kontext gen mode error tkey 2025-07-04 20:35:29 +10:00
psychedelicious
4887424ca3 chore: ruff 2025-07-04 20:35:29 +10:00
Kent Keirsey
28f6a20e71 format import block 2025-07-04 20:35:29 +10:00
Kent Keirsey
c4142e75b2 fix import 2025-07-04 20:35:29 +10:00
Kent Keirsey
fefe563127 fix resizing and versioning 2025-07-04 20:35:29 +10:00
Mary Hipp
1c72f1ff9f include flux kontext non-api models in ref image dropdown options 2025-07-04 20:35:29 +10:00
Mary Hipp
605cc7369d update flux kontext implementation to include flux kontext dev non-api models 2025-07-04 20:35:29 +10:00
Kent Keirsey
e7ce08cffa ruff format 2025-07-04 19:24:44 +10:00
Kent Keirsey
983cb5ebd2 ruff ruff 2025-07-04 19:24:44 +10:00
Kent Keirsey
52dbdb7118 ruff 2025-07-04 19:24:44 +10:00
Kent Keirsey
71e6f00e10 test fixes
fix

test

fix 2

fix 3

fix 4

yet another

attempt new fix

pray

more pray

lol
2025-07-04 19:24:44 +10:00
psychedelicious
e73150c3e6 feat(ui): improved automatic tab/panel switching on user actions 2025-07-04 19:18:03 +10:00
psychedelicious
f2426c3ab2 fix(ui): type for dnd action 2025-07-04 19:18:03 +10:00
psychedelicious
9d9c4c0f1a tidy(ui): remove unused old metadata impl 2025-07-04 17:53:47 +10:00
psychedelicious
acb930f6b9 fix(ui): flux redux saves metadata 2025-07-04 17:53:47 +10:00
psychedelicious
585b54dc7d feat(ui): ref image recall w/ old canvas metadata backup 2025-07-04 17:53:47 +10:00
psychedelicious
f65affc0ec fix(ui): do not attempt to recall ref images from canvas metadata 2025-07-04 17:53:47 +10:00
psychedelicious
22d574c92a feat(ui): canvas metadata recall 2025-07-04 17:53:47 +10:00
psychedelicious
f23be119fc refactor(ui): migrating to new metadata handlers 2025-07-04 17:53:47 +10:00
psychedelicious
2d06949e80 feat(ui): display cached metadata if it exists instead of always waiting for debounce 2025-07-04 17:53:47 +10:00
psychedelicious
67804313e1 fix(ui): add ref images to metadata 2025-07-04 17:53:47 +10:00
psychedelicious
dc23be117a refactor(ui): simplified metadata parsing (WIP) 2025-07-04 17:53:47 +10:00
psychedelicious
350de058fc refactor(ui): simplified metadata parsing (WIP) 2025-07-04 17:53:47 +10:00
psychedelicious
fd5cd707a3 refactor(ui): simplified metadata parsing (WIP) 2025-07-04 17:53:47 +10:00
psychedelicious
98ecefdce0 refactor(ui): simplified metadata parsing (WIP) 2025-07-04 17:53:47 +10:00
psychedelicious
42688a0993 refactor(ui): metadata parsing 2025-07-04 17:53:47 +10:00
psychedelicious
d94aa4abf7 feat(ui): enforce loader when switching tabs 2025-07-04 16:49:57 +10:00
psychedelicious
69a56aafed feat(ui): do not require root ref to focus on prompt 2025-07-04 16:49:57 +10:00
psychedelicious
56873f6936 feat(ui): queue and models tab are wrapped in dockview panels 2025-07-04 16:49:57 +10:00
psychedelicious
6bc6a680cf tests(ui): NavigationApi 2025-07-04 16:49:57 +10:00
psychedelicious
9a49682f60 feat(ui): utils to get tab/panel keys to prevent typos 2025-07-04 16:49:57 +10:00
psychedelicious
ff84b0a495 refactor(ui): navigation api 2025-07-04 16:49:57 +10:00
psychedelicious
bcced8a5e8 refactor(ui): navigation api 2025-07-04 16:49:57 +10:00
psychedelicious
4a18e9eaea refactor(ui): panel api (WIP) 2025-07-04 16:49:57 +10:00
psychedelicious
dde5bf61be feat(ui): use exact brand colors in loader 2025-07-04 16:49:57 +10:00
psychedelicious
987e401709 perf(ui): lora components 2025-07-04 14:55:52 +10:00
psychedelicious
5c5ac570e3 fix(ui): hardcode literals for run graph errors
When we build, the class names are minified. This hardcodes the values
to literals.
2025-07-04 14:52:08 +10:00
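A minimal sketch of the pattern; the class and function names here are illustrative, not the actual run-graph errors:

```ts
// After minification, a class's .constructor.name may become something like
// "t", so any check against the class name breaks. Hardcoding the name to a
// string literal survives minification.
export class RunGraphTimeoutError extends Error {
  constructor(message: string) {
    super(message);
    this.name = 'RunGraphTimeoutError'; // hardcoded literal, not this.constructor.name
  }
}

export const isRunGraphTimeoutError = (e: unknown): boolean =>
  e instanceof Error && e.name === 'RunGraphTimeoutError';
```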
psychedelicious
309903fe0f feat(ui): refetch gallery image names on reconnect
Maybe fixes JP's issue (again)
2025-07-04 14:49:32 +10:00
psychedelicious
f16ea43e9a feat(ui): enable RTK Query's refetchOnReconnect 2025-07-04 14:49:32 +10:00
Jeremy Gooch
d794aedb43 fix(ui): sets cfg_rescale_multiplier to 0 if there is no default. Also fixes issue with truthiness check causing 0 value to be missed. See https://github.com/invoke-ai/InvokeAI/issues/7584 2025-07-04 06:20:14 +10:00
psychedelicious
9930440f33 chore: bump version to v6.0.0rc2 2025-07-03 12:35:04 +10:00
psychedelicious
f0a6c4aa1f fix(ui): after canceling a filter, layer loses its content 2025-07-03 12:30:01 +10:00
psychedelicious
f36d22f13c fix(ui): control layers ignored in txt2img 2025-07-03 12:27:05 +10:00
Cursor Agent
e0d7fab524 Fix: Toggle right panel instead of left panel in navigation
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:15:22 +10:00
Cursor Agent
f20c230f4a Add drag-and-drop comparison image target to ImageViewerPanel
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:10:51 +10:00
Cursor Agent
05c9bc730e Fix canvas export layer bounds calculation in PSD export hook
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:07:22 +10:00
Cursor Agent
f17ac06591 Fix PSD export to use layer content bounds and crop canvas
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:07:22 +10:00
Kent Keirsey
b35f93d919 Change implementation to check $ispending 2025-07-03 12:04:27 +10:00
Cursor Agent
289d8076d8 Reset canvas session when queue item is canceled in current session
Co-authored-by: kent <kent@invoke.ai>
2025-07-03 12:04:27 +10:00
skunkworxdark
604763d20f Update flux.py
Replace T5Tokenizer with T5TokenizerFast
2025-07-03 08:04:08 +10:00
Mary Hipp
7b452f098d lint 2025-07-02 16:27:44 -04:00
Mary Hipp
b41c18d35f disable dropzone if prompt expansion is disabled 2025-07-02 16:27:44 -04:00
Mary Hipp
8328081333 properly build batch for flux kontext api batches 2025-07-02 14:27:57 -04:00
Mary Hipp Rogers
07517cf2c2 remove pulsing animation (#8181)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-02 16:12:52 +00:00
Kent Keirsey
6b98ad9095 Only display one icon on disabled state 2025-07-02 10:54:46 -04:00
Kent Keirsey
0de3967e7e remove stray file 2025-07-02 10:54:46 -04:00
Kent Keirsey
1335377fb1 Fixes 2025-07-02 10:54:46 -04:00
Cursor Agent
adbcc191d9 Add reference image enable/disable functionality
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 10:54:46 -04:00
Kent Keirsey
11fc7af1c8 fix 2025-07-02 10:47:01 -04:00
Cursor Agent
6f12fd22b9 Optimize image API invalidation tags and simplify cache invalidation logic
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 10:47:01 -04:00
Cursor Agent
324b6e2af4 Update LoRA select placeholder text for better clarity
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 10:36:45 -04:00
Mary Hipp Rogers
038010a1ca feat(ui): prompt expansion (#8140)
* initializing prompt expansion and putting response in prompt box working for all methods

* properly disable UI and show loading state on prompt box when there is a pending prompt expansion item

* misc wrapup: disable applying prompt templates, don't block textarea resize handle

* update progress to differentiate between prompt expansion and non

* cleanup

* lint

* more cleanup

* add image to background of loading state

* add allowPromptExpansion for front-end gating

* updated readiness text for needing to accept or discard

* fix tsc

* lint

* lint

* refactor(ui): prompt expansion logic

* tidy(ui): remove unnecessary changes

* revert(ui): unused arg on useImageUploadButton

* feat(ui): simplify prompt expansion state

* set pending for dragndrop and context menu

* add readiness logic for generate tab

* missing translation

* update error handling for prompt expansion

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-07-02 10:26:48 -04:00
Cursor Agent
2dd1bc54c9 Set brush tool automatically when sending image to canvas
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 10:09:22 -04:00
Kent Keirsey
8b69842678 lint 2025-07-02 09:46:32 -04:00
Kent Keirsey
9821f7c4fc Remove Canvas Session 2025-07-02 09:46:32 -04:00
Cursor Agent
2290ff4ad6 Fix: Focus viewer panel when switching to workflow view mode
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 09:42:21 -04:00
psychedelicious
8d82ad6d0b fix(api): return HTTP errors from session queue handlers 2025-07-02 08:42:06 -04:00
Mary Hipp
8ed9f652e8 lint 2025-07-02 08:25:42 -04:00
Mary Hipp
ee8ed344bd add modelRelationships and aboutModal to disable-able features 2025-07-02 08:25:42 -04:00
Mary Hipp
6d16cfdbe2 missing import 2025-07-02 08:23:13 -04:00
Mary Hipp
3ef2872dda handle flux-kontext models 2025-07-02 08:23:13 -04:00
Cursor Agent
b52ba149b4 Update regional guidance empty state translation key
Co-authored-by: kent <kent@invoke.ai>
2025-07-02 08:09:42 -04:00
Kent Keirsey
c6126c6875 Remove all references to New Sessions entirely. 2025-07-01 17:20:35 -04:00
psychedelicious
3f78ac9295 fix(ui): really do not load disabled tabs
Ensure disabled tabs are never mounted:
- Add didLoad flag to configSlice, default false
- Always merge in config - even if it is empty
- On first merge, set didLoad to true
- Until didLoad is true, mark _all_ tabs as disabled

This gets around an issue where tabs are all enabled for a brief moment
before the config is loaded.

A bit hacky but it works.
2025-07-01 10:52:28 -04:00
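A minimal sketch of the didLoad gating; the slice shape and selector are illustrative:

```ts
import { createSlice, type PayloadAction } from '@reduxjs/toolkit';

type ConfigState = {
  didLoad: boolean;
  disabledTabs: string[];
};

const initialState: ConfigState = { didLoad: false, disabledTabs: [] };

const configSlice = createSlice({
  name: 'config',
  initialState,
  reducers: {
    // Always merge in the incoming config, even if it is empty; the first
    // merge flips didLoad so tabs can be mounted.
    configChanged: (state, action: PayloadAction<Partial<ConfigState>>) => {
      Object.assign(state, action.payload);
      state.didLoad = true;
    },
  },
});

export const { configChanged } = configSlice.actions;

// Until didLoad is true, treat every tab as disabled so nothing mounts early.
export const selectIsTabDisabled = (state: { config: ConfigState }, tab: string) =>
  !state.config.didLoad || state.config.disabledTabs.includes(tab);
```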
psychedelicious
79fea1ac40 chore: bump version to v6.0.0rc1 2025-07-02 00:14:13 +10:00
psychedelicious
6eade5781d feat(ui): remove mini metadata viewer 2025-07-01 23:37:31 +10:00
psychedelicious
3d8f865fb0 fix(ui): initial panel sizing 2025-07-01 23:37:31 +10:00
psychedelicious
dc9cd22d9d feat(ui): better naming for panel apis 2025-07-01 23:37:31 +10:00
psychedelicious
fe115ff8f9 fix(ui): models & queue tab styling 2025-07-01 23:37:31 +10:00
psychedelicious
1d35aad213 feat(ui): move more things over to panel reg 2025-07-01 23:37:31 +10:00
psychedelicious
195d6ce893 refactor(ui): implement global panel registry, replace context-based panel API 2025-07-01 23:37:31 +10:00
psychedelicious
f13ced7ed4 fix(ui): rebase conflicts 2025-07-01 23:37:31 +10:00
psychedelicious
735fc276e5 tidy(ui): clean up focus/layout container 2025-07-01 23:37:31 +10:00
psychedelicious
cd3caf8c30 fix(ui): delete image hotkey 2025-07-01 23:37:31 +10:00
psychedelicious
e9012280ab fix(ui): upscaling tab boards/gallery collapse 2025-07-01 23:37:31 +10:00
psychedelicious
fa72a97794 refactor(ui): even more better focus handling 2025-07-01 23:37:31 +10:00
psychedelicious
e817631ba3 refactor(ui): focus handling for new layout system (WIP) 2025-07-01 23:37:31 +10:00
psychedelicious
d0619c033f feat(ui): add edit button to current image buttons 2025-07-01 16:29:20 +10:00
psychedelicious
6f4850f34f tidy(ui): launchpad tab with icon cleanup 2025-07-01 15:37:06 +10:00
Kent Keirsey
072cd9dee7 Styling Fixes 2025-07-01 15:37:06 +10:00
Cursor Agent
19b6dc1c1f Add custom Launchpad tab with dynamic icon based on active tab
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 15:37:06 +10:00
Cursor Agent
7566d0d6c6 Enhance workflow mode toggle with panel navigation and focus
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 15:27:21 +10:00
psychedelicious
f123888b46 feat(ui): tidy workflows tab launchpad 2025-07-01 15:24:08 +10:00
psychedelicious
aeab7d0cab feat(ui): tidy upscaling tab launchpad 2025-07-01 15:24:08 +10:00
Kent Keirsey
3f1b2c39ab Model Guide link update 2025-07-01 15:24:08 +10:00
Kent Keirsey
72e3a4b4be Fixes & Updates 2025-07-01 15:24:08 +10:00
Kent Keirsey
58e0f80138 Lint 2025-07-01 15:24:08 +10:00
Kent Keirsey
8b8e29d22d Fixes & Styling updates 2025-07-01 15:24:08 +10:00
Kent Keirsey
90201be670 lint 2025-07-01 15:24:08 +10:00
Kent Keirsey
46a5619100 Update all text to translations 2025-07-01 15:24:08 +10:00
Kent Keirsey
d608a7469e Upscale Workflow Launchpad updates & translation updates 2025-07-01 15:24:08 +10:00
Cursor Agent
a7d413d372 Refactor Upscaling and Workflows Launchpad Panels with enhanced UI
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 15:24:08 +10:00
Cursor Agent
f5c9e68dbf Fix division by zero in multi-diffusion pipeline with creativity values
Co-authored-by: kent <kent@invoke.ai>

Revert unnecessary validation changes in multi-diffusion

Fix in python instead of graphbuilder

tidy(ui): remove extraneous comment
2025-07-01 15:00:02 +10:00
psychedelicious
1ded459f03 refactor(ui): clean up related models impl for picker 2025-07-01 14:52:26 +10:00
Kent Keirsey
d9024dc230 linting fixes 2025-07-01 14:52:26 +10:00
Kent Keirsey
40528692c3 Update icon 2025-07-01 14:52:26 +10:00
Kent Keirsey
f35b05be43 simplifies Modelpicker wrapper 2025-07-01 14:52:26 +10:00
Kent Keirsey
29e87fc615 lints 2025-07-01 14:52:26 +10:00
Kent Keirsey
ca26b2718e Small Changes 2025-07-01 14:52:26 +10:00
Cursor Agent
5fa6c0b413 Enhance model picker with related models and improved filtering
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 14:52:26 +10:00
psychedelicious
c37c8c50cd tidy(ui): clean up psd export 2025-07-01 14:12:14 +10:00
Kent Keirsey
f0a4de245d Moved size constants to a reasonable spot... 2025-07-01 14:12:14 +10:00
Kent Keirsey
5db62f8643 Fix Type refs 2025-07-01 14:12:14 +10:00
Kent Keirsey
e1c478f94c Size Updates 2025-07-01 14:12:14 +10:00
Kent Keirsey
11fe3b6332 Comments 2025-07-01 14:12:14 +10:00
Kent Keirsey
e4aae1a591 prettier 2025-07-01 14:12:14 +10:00
Kent Keirsey
4d83d1c56d Linting 2025-07-01 14:12:14 +10:00
Kent Keirsey
34def323e8 Restyle & locate 2025-07-01 14:12:14 +10:00
Kent Keirsey
854956316b Fix export layers 2025-07-01 14:12:14 +10:00
Cursor Agent
91afe7884a Add PSD export functionality for canvas layers
Co-authored-by: kent <kent@invoke.ai>
2025-07-01 14:12:14 +10:00
psychedelicious
8417ee8a7b chore(ui): lint 2025-06-30 23:42:53 +10:00
psychedelicious
a035645ed3 refactor(ui): graph building respects selected tab 2025-06-30 23:42:53 +10:00
psychedelicious
e00ccba7d3 perf(ui): select only loading state for enqueueBatch mutation 2025-06-30 23:42:53 +10:00
psychedelicious
fb883d63aa refactor(ui): dedicated enqueue funcs for each tab 2025-06-30 23:42:53 +10:00
psychedelicious
b113c57fc4 refactor(ui): use redux-provided hooks for accessing store 2025-06-30 23:42:53 +10:00
psychedelicious
7636007349 fix(ui): useAppStore uses correct types 2025-06-30 23:42:53 +10:00
psychedelicious
fda86ae981 fix(app): incorrect node mappings when preparing collect nodes
The previous logic had a subtle Python bug related to the scoping of nested
generators.

Python generators are lazily evaluated - the expressions are stored and
only evaluated when needed (e.g. calling next() or list() on them)

The old logic used a variable `s`, which was continually overwritten as
the generator expressions were created. As a result, the final mappings
all use the _final_ value for `s`.

Following the consequences of this down the line, we find that collect
nodes can end up with multiple edges from exactly one of their ancestor
nodes, instead of one edge from each ancestor. Notably, it's only the
source _node_id_ that is affected - the source _fields_ have the correct
values.

So the invalid edges will point to a real node and a real field, but the
field exists on a different node.

---

This can result in a number of cryptic problems - including an error about
incompatible field types:

```
InvalidEdgeError: Field types are incompatible
(31758fd5-14a8-4de7-a840-b73ec1a1b94f.value ->
3459c793-41a2-4d82-9204-7df2d6d099ba.item)
```

Here are the conditions that lead to this error:
- The collect node has at least two incoming connections.
- The two incoming connections come from nodes of different types.
- The nodes both output a value of the same type, but the name of the
output field differs between them.

---

This commit uses non-generator logic to build up the mappings, avoiding
the issue entirely. As a bonus, it is much easier to read.
2025-06-30 23:39:28 +10:00
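The late-binding behaviour behind this bug is easy to reproduce in isolation. Below is a minimal sketch with illustrative names only (not the actual graph code), showing why every generator ends up reading the final value of `s`, and the eager alternative:

```
sources = ["node_a", "node_b", "node_c"]
fields = ["value"]

# Buggy: the body of a generator expression is evaluated lazily, so it
# captures the *variable* s, not the value s had when the generator was
# created.
lazy_mappings = []
for s in sources:
    lazy_mappings.append(f"{s}.{f}" for f in fields)

# All three generators now see the final value of s ("node_c"):
print([next(g) for g in lazy_mappings])
# ['node_c.value', 'node_c.value', 'node_c.value']

# Eager fix (conceptually what the commit does): build the mappings with
# plain loops so each iteration's value of s is used immediately.
eager_mappings = []
for s in sources:
    for f in fields:
        eager_mappings.append(f"{s}.{f}")

print(eager_mappings)
# ['node_a.value', 'node_b.value', 'node_c.value']
```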
psychedelicious
c02be4bdf4 refactor(app): lean on pydantic to get field types in edge validation logic
Previously we used Python's own type introspection utilities to determine
input and output field types. We can use pydantic to get the field types
in a clearer, more direct way.

This improvement also exposed an awkward behaviour in this utility,
where it would return None when a field doesn't exist. I've added a
comment in the code describing the issue, but changing it would require
some significant changes and I don't want to risk breaking anything.
2025-06-30 23:39:28 +10:00
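For context, here is a generic pydantic v2 sketch of the idea (not the actual edge-validation code), including the None-for-missing-field behaviour mentioned above:

```
from typing import Any

from pydantic import BaseModel


class ExampleOutput(BaseModel):
    value: int
    label: str


def get_field_type(model_cls: type[BaseModel], field_name: str) -> Any:
    # Read the declared type straight from pydantic's field metadata
    # instead of using typing introspection on the class.
    field = model_cls.model_fields.get(field_name)
    # Note the awkward behaviour mentioned above: a missing field yields
    # None instead of raising.
    return field.annotation if field is not None else None


print(get_field_type(ExampleOutput, "value"))    # <class 'int'>
print(get_field_type(ExampleOutput, "missing"))  # None
```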
psychedelicious
ed7772d993 tests(app): add more tests for complex iterate/collect graph topologies 2025-06-30 23:39:28 +10:00
psychedelicious
baae998b5b tests(app): add failing test for collector edge case
2025-06-30 23:39:28 +10:00
DustyShoe
4077ffe595 Fixed a typo 2025-06-30 15:44:23 +10:00
psychedelicious
c1937b1379 chore: ruff 2025-06-30 12:56:51 +10:00
psychedelicious
5c66dfed8e fix(app): remove errant comment from prev impl 2025-06-30 12:56:51 +10:00
psychedelicious
126dcc96c0 feat(ui): clean up logging and comments in runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
cb9c7b4a28 feat(ui): simplify runGraph logic for error handling 2025-06-30 12:56:51 +10:00
psychedelicious
e8c4f49a14 feat(ui): add .wrap() method to WrappedError 2025-06-30 12:56:51 +10:00
psychedelicious
30fffae637 feat(ui): runGraph settlement callbacks can simply return or throw 2025-06-30 12:56:51 +10:00
psychedelicious
4558a292b6 tests(ui): update runGraph tests for separate options 2025-06-30 12:56:51 +10:00
psychedelicious
825d17441c feat(ui): separate options arg for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
9b16504af9 docs(ui): improved runGraph docstring 2025-06-30 12:56:51 +10:00
psychedelicious
46c92fadff feat(ui): use system logger for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
c0467b82ac tests(ui): update runGraph tests for new error state 2025-06-30 12:56:51 +10:00
psychedelicious
6dafa67286 feat(ui): improved logging for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
eb406aa07e feat(ui): mark runGraph error properties public readonly 2025-06-30 12:56:51 +10:00
psychedelicious
d9422ffebd tests(ui): add tests for enriched cancel/timeout errors 2025-06-30 12:56:51 +10:00
psychedelicious
d5c033be4d feat(ui): enrich cancel/timeout errors when queue item cancel fails 2025-06-30 12:56:51 +10:00
psychedelicious
4662cd6f15 fix(ui): await cancelation of queue item before returning 2025-06-30 12:56:51 +10:00
psychedelicious
a740a22613 feat(ui): runGraph uses settle for all promise handling, better comments 2025-06-30 12:56:51 +10:00
psychedelicious
bf4016b4bc feat(ui): add getNodes method to Graph 2025-06-30 12:56:51 +10:00
psychedelicious
6fa7c8c2ee feat(ui): better exception naming and docstrings in runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
ea40f582da tweak(ui): naming, code style 2025-06-30 12:56:51 +10:00
psychedelicious
01caf56251 feat(ui): clearer naming in WrappedError 2025-06-30 12:56:51 +10:00
psychedelicious
42d577e65a tests(ui): check for error instance instead of message 2025-06-30 12:56:51 +10:00
psychedelicious
38d80c9ce5 fix(ui): clear cleanupFunctions when finished calling them 2025-06-30 12:56:51 +10:00
psychedelicious
6acaa8abbf refactor(ui): use deferred promise as workaround to antipattern of async promise executor 2025-06-30 12:56:51 +10:00
psychedelicious
4b84e34599 refactor(ui): better race condition handling in runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
bbd21b1eb2 feat(ui): rename isSettled -> isFinished 2025-06-30 12:56:51 +10:00
psychedelicious
4fa83a6228 feat(ui): better error handling for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
051876dcff feat(ui): ensure promise always marked as settled, better comments 2025-06-30 12:56:51 +10:00
psychedelicious
8dc6d0b5ae feat(ui): use runGraph in canvas 2025-06-30 12:56:51 +10:00
psychedelicious
40e9624954 tests(ui): edge cases in runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
ae27c83dc4 feat(ui): log when cancelation fails 2025-06-30 12:56:51 +10:00
psychedelicious
161059551b fix(ui): handle errors during cleanup 2025-06-30 12:56:51 +10:00
psychedelicious
c196f8a5d5 tests(ui): add tests for runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
2c6d22664e feat(ui): use DI to make runGraph testable 2025-06-30 12:56:51 +10:00
psychedelicious
b9ce5389ef fix(ui): clean up signal 2025-06-30 12:56:51 +10:00
psychedelicious
d1cbf56695 feat(ui): iterate on runGraph 2025-06-30 12:56:51 +10:00
psychedelicious
e379ac12c3 feat(ui): abstraction to make a graph await-able 2025-06-30 12:56:51 +10:00
psychedelicious
aa10373292 feat(ui): loosen typings for Result 2025-06-30 12:56:51 +10:00
psychedelicious
780f3692a0 chore(ui): typegen 2025-06-30 12:56:51 +10:00
psychedelicious
3604dcfdd1 feat(api): return list of enqueued item ids when enqueuing 2025-06-30 12:56:51 +10:00
Jonathan
2b1cffde5e typegen 2025-06-30 11:28:02 +10:00
Jonathan
83d642ed15 Update flux_denoise.py
Fixed version to 4.0.0
2025-06-30 11:28:02 +10:00
Jonathan
455c73235e Update flux_denoise.py
Updated version, removed WithBoard and WithMetadata
2025-06-30 11:28:02 +10:00
psychedelicious
8efef8da41 feat(ui): workflows styling tweaks 2025-06-30 11:17:29 +10:00
psychedelicious
060a9e57b9 fix(ui): prevent NaN from getting into konva internals 2025-06-30 10:43:11 +10:00
skunkworxdark
099d75ca1e use "\u2581" instead of the character itself for clarity 2025-06-30 10:40:31 +10:00
skunkworxdark
bbb5d68146 Update flux_text_encoder.py
Added tokenizer logging to flux
2025-06-30 10:40:31 +10:00
psychedelicious
9066dc1839 tidy(nodes): remove extraneous comments & add useful ones 2025-06-27 18:27:46 +10:00
psychedelicious
075345bffd feat(app): add flux kontext dev to starter models 2025-06-27 18:27:46 +10:00
psychedelicious
74d1239c87 chore(ui): typegen 2025-06-27 18:27:46 +10:00
Kent Keirsey
51e1c56636 ruff 2025-06-27 18:27:46 +10:00
Kent Keirsey
ca1df60e54 Explain the Magic 2025-06-27 18:27:46 +10:00
Cursor Agent
7549c1250d Add FLUX Kontext conditioning support for reference images
Co-authored-by: kent <kent@invoke.ai>

Fix Kontext sequence length handling in Flux denoise invocation

Co-authored-by: kent <kent@invoke.ai>

Fix Kontext step callback to handle combined token sequences

Co-authored-by: kent <kent@invoke.ai>

fix ruff

Fix Flux Kontext
2025-06-27 18:27:46 +10:00
psychedelicious
df8751b5a1 fix(ui): remove extraneous rect in stagingareamodule 2025-06-27 15:45:53 +10:00
psychedelicious
651b80b997 fix(ui): remove extraneous syncPlaceholderSize method and calls 2025-06-27 15:45:53 +10:00
psychedelicious
5d236ae4e7 fix(ui): canvas staging waiting for image placeholder sizing and layout 2025-06-27 15:45:53 +10:00
psychedelicious
e5dc606f5e fix(ui): get accurate theme tokens 2025-06-27 15:45:53 +10:00
Kent Keirsey
dc6b8e13bd prettier 2025-06-27 15:45:53 +10:00
Cursor Agent
c1b34e1f11 Standardize UI spacing and constants across canvas and image components
Co-authored-by: kent <kent@invoke.ai>
2025-06-27 15:45:53 +10:00
Cursor Agent
89f1684072 Improve placeholder styling with badge and refined text positioning
Co-authored-by: kent <kent@invoke.ai>
2025-06-27 15:45:53 +10:00
Kent Keirsey
14fbee17a3 Rule of 3rds Composition Guide (#8130)
* Add Rule of 4 composition guide to canvas settings and rendering

Co-authored-by: kent <kent@invoke.ai>

* Rename Rule of 4 Guide to Rule of Thirds in canvas composition guide

Co-authored-by: kent <kent@invoke.ai>

* Updates to comp guide and naming

* Fix reference

* Update translation keys and organize settings.

* revert to previous canvas manager for conflict

* Re-add composition guide.

* Fix lint

* prettier

* feat(ui): improve markup in canvas settings popover

* feat(ui): use brand colors for canvas rule of thirds guide

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-06-27 15:05:34 +10:00
psychedelicious
5dbc32e06e feat(ui): minor restyle of style preset list 2025-06-27 14:40:35 +10:00
psychedelicious
23baf61e51 fix(ui): remove extraneous slice migration for style presets 2025-06-27 14:40:35 +10:00
Kent Keirsey
5e55f6074b prettier 2025-06-27 14:40:35 +10:00
Kent Keirsey
f7c555e501 Change to Toggle Tooltip 2025-06-27 14:40:35 +10:00
Cursor Agent
6aa605e811 Add toggle for showing/hiding style preset prompt previews
Co-authored-by: kent <kent@invoke.ai>
2025-06-27 14:40:35 +10:00
psychedelicious
f51014e108 feat(ui): make launchpad button its own component 2025-06-27 14:37:30 +10:00
psychedelicious
9862ba9210 feat(ui): improved starter model buttons & tooltips 2025-06-27 14:37:30 +10:00
psychedelicious
920aea08cc tidy(ui): remove unused translation strings 2025-06-27 14:37:30 +10:00
psychedelicious
39e584297e feat(ui): fix missing translations 2025-06-27 14:37:30 +10:00
psychedelicious
62a14bb935 feat(ui): use enriched starter model metadata 2025-06-27 14:37:30 +10:00
psychedelicious
d7ae2cdf75 chore(ui): typegen 2025-06-27 14:37:30 +10:00
psychedelicious
6172c859ac feat(api): enrich starter model bundle metadata 2025-06-27 14:37:30 +10:00
psychedelicious
b26fb1f617 feat(ui): simplify markup for install models launchpad form 2025-06-27 14:37:30 +10:00
psychedelicious
05167dfd7a feat(ui): use existing design language for install model bundle buttons 2025-06-27 14:37:30 +10:00
psychedelicious
c090ea7387 feat(ui): use existing design language for install model launchpad buttons 2025-06-27 14:37:30 +10:00
psychedelicious
7ba6c67049 feat(ui): named install models tabs 2025-06-27 14:37:30 +10:00
psychedelicious
3de186061d chore(ui): lint 2025-06-27 14:37:30 +10:00
Kent Keirsey
a716381733 Model Launchpad prettier 2025-06-27 14:37:30 +10:00
Kent Keirsey
fb5df06835 Update to include translations and import fixes 2025-06-27 14:37:30 +10:00
Kent Keirsey
33c597c224 fix lint 2025-06-27 14:37:30 +10:00
Kent Keirsey
19d882d038 Address comments 2025-06-27 14:37:30 +10:00
Kent Keirsey
ee4bc49bd4 Prettier. 2025-06-27 14:37:30 +10:00
Kent Keirsey
188cf37f48 fix lint 2025-06-27 14:37:30 +10:00
Kent Keirsey
15a0a7134c fix circ dependency 2025-06-27 14:37:30 +10:00
Kent Keirsey
22cea0de8b Remove scrap 2025-06-27 14:37:30 +10:00
Kent Keirsey
cd21816d12 Model Launchpad 2025-06-27 14:37:30 +10:00
psychedelicious
605b912ba4 fix(ui): remove noop hook 2025-06-27 11:37:47 +10:00
psychedelicious
52e31112f9 chore(ui): lint 2025-06-27 11:37:47 +10:00
Kent Keirsey
a4c9346cd7 lint 2025-06-27 11:37:47 +10:00
Kent Keirsey
a1647e4c6e Address comments 2025-06-27 11:37:47 +10:00
Kent Keirsey
8c9ca088a7 update tooltip 2025-06-27 11:37:47 +10:00
Cursor Agent
7a7a2e147c Add toggle for non-raster layers with hotkey and UI button 2025-06-27 11:37:47 +10:00
psychedelicious
adf4cc750a fix(ui): Fix LoRA picker to default to current base model architecture (#8135)
Enhance LoRA picker to default filter by current base model architecture

## Summary
Fixes new LoRA picker to auto select the architecture filter for the
current model group

## Related Issues / Discussions
N/A

## QA Instructions

Open LoRA menu with any model group selected. The right models should be
filtered.

## Merge Plan
Merge when ready.

## Checklist

- [X] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-06-27 11:21:39 +10:00
psychedelicious
9f1ea9d1c7 fix(ui): use existing GroupStatusMap type 2025-06-27 11:19:24 +10:00
Cursor Agent
571d286506 Enhance LoRA picker to default to current base model architecture
Co-authored-by: kent <kent@invoke.ai>

Enhance LoRA picker to default filter by current base model architecture

Co-authored-by: kent <kent@invoke.ai>
2025-06-26 20:43:43 -04:00
Mary Hipp
1320a2c5f8 add option to override text for no options available 2025-06-26 18:09:57 -04:00
Mary Hipp
26a9b3131d convert LoRA picker to use new model picker component 2025-06-26 18:09:57 -04:00
psychedelicious
d48140b35d fix(ui): regional guidance ref image not selecting 2025-06-26 10:05:25 -04:00
psychedelicious
9757bb0325 refactor(ui): canvas flow (#8069) 2025-06-26 21:24:17 +10:00
psychedelicious
38ccd8e09c chore: bump version to v6.0.0a10 2025-06-26 21:06:24 +10:00
psychedelicious
7759b166a9 fix(ui): dnd on images
Need to use callback refs else chakra's image fallback breaks the ref
2025-06-26 20:53:50 +10:00
psychedelicious
9fc51c7a6e fix(ui): optimistic updates when sorting by oldest first 2025-06-26 20:24:52 +10:00
psychedelicious
62fa4f42f5 fix(ui): more viewer progress nonsense 2025-06-26 20:17:47 +10:00
psychedelicious
418ad0de38 fix(ui): rebase conflicts 2025-06-26 20:06:26 +10:00
psychedelicious
f4a411326e chore: bump version to v6.0.0a9 2025-06-26 20:00:41 +10:00
psychedelicious
6358f39ebb chore(ui): lint 2025-06-26 20:00:40 +10:00
psychedelicious
ea8da0bfbf chore: ruff 2025-06-26 20:00:40 +10:00
psychedelicious
5385282325 feat(ui): use consistent gallery scrollseek placeholder component 2025-06-26 20:00:40 +10:00
psychedelicious
0bf84ab803 feat(ui): gallery scrollbars autohide 2025-06-26 20:00:40 +10:00
psychedelicious
82f31f2258 feat(ui): tweak canvas entity group list button layout 2025-06-26 20:00:40 +10:00
psychedelicious
966dd8857d feat(ui): boards and gallery panel collapse 2025-06-26 20:00:40 +10:00
psychedelicious
1c778bd719 fix(ui): some progress image jank 2025-06-26 20:00:40 +10:00
psychedelicious
394a14cf61 fix(ui): progress in viewer bg color 2025-06-26 20:00:40 +10:00
psychedelicious
0e843823d1 fix(ui): ensure image selected on first load 2025-06-26 20:00:40 +10:00
psychedelicious
29462e62d2 fix(ui): handle selecting images/boards on invocation complete 2025-06-26 20:00:40 +10:00
psychedelicious
175c0147f8 fix(ui): auto image selection on invocation complete, board change 2025-06-26 20:00:40 +10:00
psychedelicious
df6e67c982 fix(ui): queue count badge showing up multiple times 2025-06-26 20:00:40 +10:00
psychedelicious
4612f0ac50 fix(ui): tab bar shrinkage 2025-06-26 20:00:39 +10:00
psychedelicious
386a932f2a feat(ui): clean up GalleryImage 2025-06-26 20:00:39 +10:00
psychedelicious
32438532b0 fix(ui): prevent duplicate initial gallery fetches 2025-06-26 20:00:39 +10:00
psychedelicious
ab5cb2c264 refactor: optimistic gallery updates 2025-06-26 20:00:39 +10:00
psychedelicious
504daa0ae5 Revert "build(ui): adopt sonda over rollup-plugin-visualizer to examine bundle"
This reverts commit e0cf2a8046.
2025-06-26 20:00:39 +10:00
psychedelicious
14f7c98e8a chore(ui): bump package version 2025-06-26 20:00:39 +10:00
psychedelicious
ab39305223 chore(ui): upgrade zod to v4 2025-06-26 20:00:39 +10:00
psychedelicious
7948bca864 build(ui): adopt sonda over rollup-plugin-visualizer to examine bundle
Requires a change to tsconfig module/moduleResolution settings. We were
on old legacy values anyways so good to update it.
2025-06-26 20:00:39 +10:00
psychedelicious
1a39d22b6c feat(ui): migrate from lodash-es to es-toolkit 2025-06-26 20:00:39 +10:00
psychedelicious
9424271d12 revert(ui): undo accidental downgrade of rtk 2025-06-26 20:00:39 +10:00
psychedelicious
b5acc204a8 feat(ui): migrate from lodash.isEqual to objectEquals 2025-06-26 20:00:39 +10:00
psychedelicious
7aefa8f36b fix(ui): invalidate image name list cache on mutation 2025-06-26 20:00:38 +10:00
psychedelicious
242da9e888 fix(ui): hide ref panel when last one is deleted 2025-06-26 20:00:38 +10:00
psychedelicious
1aedc26041 feat(ui): handle ref image deletion autoswitch 2025-06-26 20:00:38 +10:00
psychedelicious
2c7fa90892 chore: bump version to v6.0.0a8 2025-06-26 20:00:38 +10:00
psychedelicious
6c8cf99ad2 feat(ui): revised ref image panel 2025-06-26 20:00:38 +10:00
psychedelicious
a92ba2542c feat(ui): switch to canvas tab when using launchpad 2025-06-26 20:00:38 +10:00
psychedelicious
2367b9f945 chore: bump version to v6.0.0a7 2025-06-26 20:00:38 +10:00
psychedelicious
a928ed0204 chore(ui): dpdm 2025-06-26 20:00:38 +10:00
psychedelicious
e164451dfe chore: ruff 2025-06-26 20:00:38 +10:00
psychedelicious
d74d079356 fix(ui): restore gallery selection count tag 2025-06-26 20:00:38 +10:00
psychedelicious
0eb4360c01 fix(ui): debounce gallery min width value 2025-06-26 20:00:38 +10:00
psychedelicious
937c03f2ec chore(ui): disable debug logger 2025-06-26 20:00:38 +10:00
psychedelicious
f7b249252d fix(ui): issues with progress viewer 2025-06-26 20:00:37 +10:00
psychedelicious
b2b42be51c refactor: remove unused methods/routes, fix some gallery invalidation issues 2025-06-26 20:00:37 +10:00
psychedelicious
98368b0665 feat(ui): restore gallery hotkeys (except delete) 2025-06-26 20:00:37 +10:00
psychedelicious
b5eb3d9798 fix(ui): gallery updates on image completion 2025-06-26 20:00:37 +10:00
psychedelicious
1218f49e20 fix(ui): remove context from DOM props 2025-06-26 20:00:37 +10:00
psychedelicious
89c609fd61 feat(ui): calculate gridTemplateColumns in selector 2025-06-26 20:00:37 +10:00
psychedelicious
b204fb6a91 chore: ruff 2025-06-26 20:00:37 +10:00
psychedelicious
6e3e316416 chore: bump version to v6.0.0a6 2025-06-26 20:00:37 +10:00
psychedelicious
bf5fc9512d fix(ui): minor jank when switching images rapidly 2025-06-26 20:00:37 +10:00
psychedelicious
7080889ed4 feat(ui): scrollbar styles 2025-06-26 20:00:37 +10:00
psychedelicious
adea983bfc refactor: gallery scroll (improved impl) 2025-06-26 20:00:37 +10:00
psychedelicious
f68d8ed36a refactor: gallery scroll (improved impl) 2025-06-26 20:00:37 +10:00
psychedelicious
d45197e0af refactor: gallery scroll (improved impl) 2025-06-26 20:00:36 +10:00
psychedelicious
434d8a2b12 refactor: gallery scroll (improved impl) 2025-06-26 20:00:36 +10:00
psychedelicious
f55c593705 refactor: gallery scroll (improved impl) 2025-06-26 20:00:36 +10:00
psychedelicious
8327d86774 refactor: gallery scroll (improved impl) 2025-06-26 20:00:36 +10:00
psychedelicious
c8254710e6 refactor: gallery scroll (improved impl) 2025-06-26 20:00:36 +10:00
psychedelicious
0a8f647260 refactor: gallery scroll (improved impl) 2025-06-26 20:00:36 +10:00
psychedelicious
32a5e9652a refactor: gallery scroll (improved impl) 2025-06-26 20:00:36 +10:00
psychedelicious
87909a06a8 refactor: gallery scroll (improved impl) 2025-06-26 20:00:36 +10:00
psychedelicious
2c8ce6f2f4 refactor: gallery scroll (improved impl) 2025-06-26 20:00:36 +10:00
psychedelicious
bee4cf41b4 refactor: gallery scroll 2025-06-26 20:00:36 +10:00
psychedelicious
049a8d8144 fix(ui): fix metadata toggle stuck disabled 2025-06-26 20:00:36 +10:00
psychedelicious
ac81ec41c3 chore: bump version to v6.0.0a5 2025-06-26 20:00:35 +10:00
psychedelicious
a294e8e0fd chore(ui): lint 2025-06-26 20:00:35 +10:00
psychedelicious
4665f0df40 refactor(ui): use image names for selection instead of dtos
Update the frontend to incorporate the previous changes to how image
selection and general image identification is handled in the frontend.
2025-06-26 20:00:35 +10:00
psychedelicious
70382294f5 chore(ui): typegen 2025-06-26 20:00:35 +10:00
psychedelicious
4028cadfaf feat(api): return more data when doing image/board mutations
When we delete images, boards, or do any other board mutation, we need
to invalidate numerous query caches and related internal frontend state.
This gets complicated very quickly.

We can drastically reduce the complexity by having the backend return
some more information when we make these mutations.

For example, when deleting a list of images by name, we can return a
list of deleted image name and affected boards. The frontend can use
this information to determine which queries to invalidate with far less
tedium.

This will also enable the more efficient storage of images (e.g. in the
gallery selection). Previously, we had to store the entire image DTO
object, else we wouldn't be able to figure out which queries to
invalidate. But now that the backend tells us exactly what images/boards
have changed, we can just store image names in frontend state. This
amounts to a substantial improvement in DX and reduction in frontend
complexity.
2025-06-26 20:00:35 +10:00
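A hypothetical sketch of what such an enriched mutation response could look like, written with pydantic; the field names here are illustrative, not the actual API schema:

```
from pydantic import BaseModel, Field


class DeleteImagesResult(BaseModel):
    """Hypothetical enriched response for a bulk image deletion."""

    deleted_images: list[str] = Field(description="Names of the images that were actually deleted")
    affected_boards: list[str] = Field(description="IDs of boards whose contents changed")


# The frontend only needs names/IDs to decide which query caches to
# invalidate - no full image DTOs need to live in client state.
result = DeleteImagesResult(
    deleted_images=["img_1.png", "img_2.png"],
    affected_boards=["board_abc", "none"],
)
print(result.model_dump())
```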
psychedelicious
d23cdfd0ad feat(ui): viewer integrates progress (wip) 2025-06-26 20:00:35 +10:00
psychedelicious
f0ba693922 feat(ui): switch to viewer/canvas on invoke 2025-06-26 20:00:35 +10:00
psychedelicious
214005d795 feat(ui): generation progress tab improvements 2025-06-26 20:00:35 +10:00
psychedelicious
34aa131115 feat(ui): show last progress message & placeholder in generation progress panel 2025-06-26 20:00:35 +10:00
psychedelicious
5d8061bea9 fix(ui): staging area does not show placeholder on first render 2025-06-26 20:00:35 +10:00
psychedelicious
36ec1015d6 feat(ui): double-click staging area image to disable auto-switch 2025-06-26 20:00:35 +10:00
psychedelicious
7208373576 fix(ui): reset last started item id when doing autoswitch 2025-06-26 20:00:35 +10:00
psychedelicious
e10afe3026 feat(ui): re-implement multiple auto-switch modes 2025-06-26 20:00:34 +10:00
psychedelicious
399d6e7bce chore: bump version to v6.0.0a4 2025-06-26 20:00:34 +10:00
psychedelicious
8d0fe5522b feat(ui): no model error state for ref images 2025-06-26 20:00:34 +10:00
psychedelicious
81341deb46 feat(ui): mini metadata viewer 2025-06-26 20:00:34 +10:00
psychedelicious
a30933b09c feat(ui): clean up image view components & code 2025-06-26 20:00:34 +10:00
psychedelicious
3264188ffd fix(ui): launchpad layouts 2025-06-26 20:00:34 +10:00
psychedelicious
3984b341e1 fix(ui): don't use layers when generating on generate tab 2025-06-26 20:00:34 +10:00
psychedelicious
041023df53 feat(ui): tweak vertical tab bar layout 2025-06-26 20:00:34 +10:00
psychedelicious
b06f76cdb6 fix(ui): unable to resize prompt box bc negative prompt button is over
the handle
2025-06-26 20:00:34 +10:00
psychedelicious
852badc90b feat(ui): standardize auto layout structure 2025-06-26 20:00:34 +10:00
psychedelicious
01953cf057 feat(ui): tweak dockview tabs 2025-06-26 20:00:34 +10:00
psychedelicious
241844bdef refactor(ui): rip out image viewer as modal 2025-06-26 20:00:34 +10:00
psychedelicious
33a28ad4f9 chore: bump version to v6.0.0a3 2025-06-26 20:00:34 +10:00
psychedelicious
7c4550cbd5 chore(ui): lint 2025-06-26 20:00:33 +10:00
psychedelicious
553d1a6ac6 feat(ui): restore all panel hotkeys 2025-06-26 20:00:33 +10:00
psychedelicious
f4794e409b fix(ui): generate tab hotkey 2025-06-26 20:00:33 +10:00
psychedelicious
df87800d61 feat(ui): restore floating panel buttons 2025-06-26 20:00:33 +10:00
psychedelicious
16993cd216 feat(ui): get all tabs working w/ new layout 2025-06-26 20:00:33 +10:00
psychedelicious
7f222ffb9d fix(ui): unnecessary dependency on tab selection in
useCanvasDeleteLayerHotkey
2025-06-26 20:00:33 +10:00
psychedelicious
e0ed56ff8d fix(ui): inverted logic for resume queue button 2025-06-26 20:00:33 +10:00
psychedelicious
e7e1142c77 feat(ui): get layouts working 2025-06-26 20:00:33 +10:00
psychedelicious
fcaeba290e feat(ui): canvas launchpad 2025-06-26 20:00:33 +10:00
psychedelicious
6eecdca56c wip 2025-06-26 20:00:33 +10:00
psychedelicious
7f44da4902 fix(ui): wonky stage sizing on first visibility 2025-06-26 20:00:33 +10:00
psychedelicious
abaa33e22c wip 2025-06-26 20:00:32 +10:00
psychedelicious
d5c238e7c2 feat(ui): port UI slice to zod 2025-06-26 20:00:32 +10:00
psychedelicious
18775e8b67 fix(ui): only show weight for IP adapters 2025-06-26 20:00:32 +10:00
psychedelicious
903776bfbc feat(ui): represent IP adapter weight in ref image thumbnail 2025-06-26 20:00:32 +10:00
psychedelicious
a5baf0c102 fix(ui): overflow on ref image model 2025-06-26 20:00:32 +10:00
psychedelicious
a7e45731ec feat(ui): ref images feel more like buttons 2025-06-26 20:00:32 +10:00
psychedelicious
32aa3e6d48 feat(ui): switch tab on drag over tab button 2025-06-26 20:00:32 +10:00
psychedelicious
2f9ea91896 feat(ui): tweak splash screen layout 2025-06-26 20:00:32 +10:00
psychedelicious
5ac5115269 chore(ui): lint 2025-06-26 20:00:32 +10:00
psychedelicious
161624c722 feat(ui): rework simple session initial state 2025-06-26 20:00:32 +10:00
psychedelicious
c31cb0b106 fix(ui): invoke button tooltip on generate tab 2025-06-26 20:00:32 +10:00
psychedelicious
893f7a8744 fix(ui): progress image fixes 2025-06-26 20:00:32 +10:00
psychedelicious
2e0824a799 feat(ui): make autoswitch on/off
When the invocation cache is used, we might skip all progress images. This can prevent auto-switch-on-first-progress from working, as we don't get any of those events.

It's much easier to only support auto-switch on complete.
2025-06-26 20:00:31 +10:00
psychedelicious
ed05bf2df3 feat(ui): refine ref images UI 2025-06-26 20:00:31 +10:00
psychedelicious
0f1a69a0c3 feat(ui): toggleable negative prompt 2025-06-26 20:00:31 +10:00
psychedelicious
450a0bf142 fix(ui): remove old isSelected from refImageAdded call 2025-06-26 19:59:05 +10:00
psychedelicious
a28c15d545 chore: bump version to v6.0.0a2 2025-06-26 19:59:05 +10:00
psychedelicious
1b1e1983d9 fix(ui): update queue item preview images on init of queue items context 2025-06-26 19:59:05 +10:00
psychedelicious
d08e2fbd82 fix(ui): hack to close chakra tooltips on drag 2025-06-26 19:59:04 +10:00
psychedelicious
45b1ef6231 tweak(ui): ref image header 2025-06-26 19:59:04 +10:00
psychedelicious
3bb446c08f experiment(ui): add generate tab 2025-06-26 19:59:04 +10:00
psychedelicious
8d1ab0a2e5 refactor(ui): ref images (WIP) 2025-06-26 19:59:04 +10:00
psychedelicious
48e2e7e4a1 refactor(ui): ref images (WIP) 2025-06-26 19:59:04 +10:00
psychedelicious
5a2f5c105d refactor(ui): refImage.ipAdapter -> refImage.config 2025-06-26 19:57:15 +10:00
psychedelicious
aa93e95a94 feat(ui): split out ref images into own slice (WIP) 2025-06-26 19:55:21 +10:00
psychedelicious
a5e5cbd7c3 feat(ui): simple session initial state cards are buttons 2025-06-26 19:51:37 +10:00
psychedelicious
baa9141be3 chore(ui): dpdm 2025-06-26 19:51:37 +10:00
psychedelicious
c7ed351bab refactor(ui): async modal pattern; use for deleting images
This was needed for a canvas flow change which is currently paused, but the new API is much much nicer to use, so I am keeping it.
2025-06-26 19:51:37 +10:00
psychedelicious
8c17bde4ea fix(ui): use imageDTO in staging area 2025-06-26 19:51:37 +10:00
psychedelicious
ba082ccc2f fix(ui): wait until last queue item deleted before flagging canvas session finished 2025-06-26 19:51:37 +10:00
psychedelicious
01784fb3bf feat(ui): store output image DTO in session context instead of just the name 2025-06-26 19:51:37 +10:00
psychedelicious
a71a0e143c feat(ui): add AppGetState type 2025-06-26 19:51:37 +10:00
psychedelicious
94afc13813 feat(ui): close viewer on escape 2025-06-26 19:51:37 +10:00
psychedelicious
d640a9001b fix(ui): switch only on first progress image 2025-06-26 19:51:37 +10:00
psychedelicious
711fe91b24 feat(ui): add on first progress autoswitch mode 2025-06-26 19:51:37 +10:00
psychedelicious
2f26657c17 feat(ui): move canvas-specific staging subscriptions to CanvasStagingAreaModule 2025-06-26 19:51:37 +10:00
psychedelicious
6754fde935 chore(ui): lint 2025-06-26 19:51:37 +10:00
psychedelicious
ac206f4767 feat(ui): make main panel styling and title consistent 2025-06-26 19:51:37 +10:00
psychedelicious
c316f07fb2 feat(ui): add startover button to canvas toolbar 2025-06-26 19:51:36 +10:00
psychedelicious
e81dde0933 feat(ui): fiddle w/ staging area header 2025-06-26 19:51:36 +10:00
psychedelicious
9f392c8c3c feat(ui): remove technical progress message from full preview 2025-06-26 19:51:36 +10:00
psychedelicious
2531366386 feat(ui): simple session initial state 2025-06-26 19:51:36 +10:00
psychedelicious
9df69496e4 feat(ui): remove vary and edit as control buttons 2025-06-26 19:51:36 +10:00
psychedelicious
2ddcde13ff refactor(ui): migrate from canceling queue items to deleting, make queue hook APIs consistent 2025-06-26 19:51:36 +10:00
psychedelicious
cc5083599d fix(ui): mini preview bg color 2025-06-26 19:51:36 +10:00
psychedelicious
2431060a7e fix(ui): hide layers when not on canvas tab 2025-06-26 19:51:36 +10:00
psychedelicious
592c842632 build(ui): temporarily ignore all knip issues 2025-06-26 19:51:36 +10:00
psychedelicious
bc3550f238 feat(ui): finish generation when discarding last item 2025-06-26 19:51:36 +10:00
psychedelicious
23511d68db feat(ui): when discarding last item, select new last instead of first 2025-06-26 19:51:36 +10:00
psychedelicious
cd0668dd0b feat(ui): tweak staging image display 2025-06-26 19:51:35 +10:00
psychedelicious
bf5ed61b84 feat(ui): add staging area toolbar to simple session 2025-06-26 19:51:35 +10:00
psychedelicious
3038a797a6 fix(ui): ensure canvas tool modules are destroyed 2025-06-26 19:51:35 +10:00
psychedelicious
9bbc31b2d9 fix(ui): reset layers when changing session type 2025-06-26 19:51:35 +10:00
psychedelicious
526e6335a1 feat(ui): improved staging placeholders 2025-06-26 19:51:35 +10:00
psychedelicious
1412c079ad feat(ui): improved staging placeholders 2025-06-26 19:51:35 +10:00
psychedelicious
6570c0c3b9 feat(ui): more staging fixes 2025-06-26 19:51:35 +10:00
psychedelicious
3a08ea799a feat(ui): update canvas session state handling for new staging strat 2025-06-26 19:51:35 +10:00
psychedelicious
e3fc244126 chore(ui): lint (partial cleanup) 2025-06-26 19:51:35 +10:00
psychedelicious
56938ca0a1 feat(ui): rough out canvas staging area 2025-06-26 19:51:34 +10:00
psychedelicious
5d80642ea4 feat(app): support deleting queue items by id or destination 2025-06-26 19:50:37 +10:00
psychedelicious
da4b084a8b feat(ui): tweak canvas scroll to zoom feel 2025-06-26 19:50:37 +10:00
psychedelicious
86e1a37a00 docs(ui): add comment about auto-switch not being quite right yet 2025-06-26 19:50:37 +10:00
psychedelicious
ea34690709 feat: canvas flow rework (wip) 2025-06-26 19:50:37 +10:00
psychedelicious
c8df7cd2c0 feat(ui): prevent flicker of image action buttons 2025-06-26 19:50:37 +10:00
psychedelicious
628367b97b feat(ui): move socket events handling into ctx component 2025-06-26 19:50:37 +10:00
psychedelicious
002816653e feat(ui): modularize all staging area logic so it can be shared w/ canvas more easily 2025-06-26 19:50:37 +10:00
psychedelicious
b05de8634d perf(ui): queue actions menu is lazy 2025-06-26 19:50:36 +10:00
psychedelicious
5088e700ad fix(ui): cursor on staging area preview image 2025-06-26 19:50:36 +10:00
psychedelicious
d2155e98ef feat(ui): remove clear queue ui components 2025-06-26 19:50:36 +10:00
psychedelicious
7ec511da01 feat(app): do not prune queue on startup
With the new canvas design, this will result in loss of staging area images.
2025-06-26 19:50:36 +10:00
psychedelicious
985cd8272b tidy(ui): component organization 2025-06-26 19:50:36 +10:00
psychedelicious
cd136194ad fix(ui): prevent drag of progress images 2025-06-26 19:50:36 +10:00
psychedelicious
2e2ac71278 feat: canvas flow rework (wip) 2025-06-26 19:50:36 +10:00
psychedelicious
db4220fb20 feat: canvas flow rework (wip) 2025-06-26 19:50:36 +10:00
psychedelicious
84f70942e7 chore(ui): typegen 2025-06-26 19:50:36 +10:00
psychedelicious
0af20b03e5 feat(api): remove status from list all queue items query 2025-06-26 19:50:36 +10:00
psychedelicious
e16414b452 tidy(ui): app layout components 2025-06-26 19:50:36 +10:00
psychedelicious
5dbc2a74a2 feat: canvas flow rework (wip) 2025-06-26 19:50:36 +10:00
psychedelicious
ad736bc190 feat: canvas flow rework (wip) 2025-06-26 19:50:35 +10:00
psychedelicious
0e9b71801a feat: canvas flow rework (wip) 2025-06-26 19:50:35 +10:00
psychedelicious
e80f0b2b43 fix(ui): unstable selector results in lora drop down 2025-06-26 19:50:35 +10:00
psychedelicious
c9042e52d4 feat: canvas flow rework (wip) 2025-06-26 19:50:35 +10:00
psychedelicious
8a78e37634 feat: canvas flow rework (wip) 2025-06-26 19:50:35 +10:00
psychedelicious
5e93f58530 wip progress events 2025-06-26 19:50:35 +10:00
psychedelicious
a3851e0b08 refactor(ui): canvas flow (wip) 2025-06-26 19:50:35 +10:00
psychedelicious
eb45a457e9 fix(ui): ref goes undefined in GalleryImage
This appears to be a bug in Chakra UI v2 - use of a fallback component makes the ref passed to an image end up undefined. Had to remove the skeleton loader fallback component.
2025-06-26 19:50:35 +10:00
psychedelicious
1446d3490b fix(ui): merge refs when forwarding in DndImage 2025-06-26 19:50:35 +10:00
psychedelicious
579318af70 fix(ui): remove unused sessionId field from type 2025-06-26 19:50:35 +10:00
psychedelicious
57bfae6774 fix(ui): ensure all args are passed to handler when creating new canvas from image 2025-06-26 19:50:35 +10:00
psychedelicious
2a92524546 feat(ui): bookmark new inpaint masks 2025-06-26 19:50:34 +10:00
psychedelicious
7a5fa25b48 feat(ui): support bookmarking an entity when adding it 2025-06-26 19:50:34 +10:00
psychedelicious
b3f3020793 fix(ui): ensure images are added to gallery in simple sessions 2025-06-26 19:50:34 +10:00
psychedelicious
650809e50d feat(ui): images always added to gallery in simple session 2025-06-26 19:50:34 +10:00
psychedelicious
7308428f32 wip 2025-06-26 19:50:34 +10:00
psychedelicious
4dc3f1bcee refactor(ui): canvas flow (wip) 2025-06-26 19:50:34 +10:00
psychedelicious
faeb5f0c3b refactor(ui): canvas flow (wip) 2025-06-26 19:50:34 +10:00
psychedelicious
d985dfe821 refactor(ui): canvas flow events (wip) 2025-06-26 19:50:34 +10:00
psychedelicious
ce5ae83689 refactor(ui): canvas flow (wip) 2025-06-26 19:50:34 +10:00
psychedelicious
c0428ee7ef refactor(ui): canvas flow (wip) 2025-06-26 19:50:34 +10:00
psychedelicious
aa3b2106d4 refactor(ui): canvas flow (wip) 2025-06-26 19:50:34 +10:00
psychedelicious
cf2d67ef3d refactor(ui): canvas flow (wip) 2025-06-26 19:50:33 +10:00
psychedelicious
c4d1e78f59 fix(ui): circular import issue 2025-06-26 19:50:33 +10:00
psychedelicious
02e4a3aa82 refactor(ui): params state zodification 2025-06-26 19:50:33 +10:00
psychedelicious
a0b0c30be9 refactor(ui): move params state to big file of canvas zod stuff 2025-06-26 19:50:33 +10:00
psychedelicious
5c4cbc7fa2 refactor(ui): zod-ify params slice state 2025-06-26 19:50:33 +10:00
psychedelicious
5f2f12f803 refactor(ui): org state in prep for new flow 2025-06-26 19:50:33 +10:00
psychedelicious
c9cd0a87be refactor(ui): image viewer & comparison convolutedness 2025-06-26 19:49:01 +10:00
psychedelicious
668c475271 feat(ui): default canvas tool is move 2025-06-26 19:49:01 +10:00
psychedelicious
341910739e chore(ui): bump @reduxjs/toolkit to latest 2025-06-26 19:49:01 +10:00
psychedelicious
53a3dc52bc feat(ui): viewer is a modal (wip) 2025-06-26 19:49:01 +10:00
Billy
23b0a4a7f4 Update uv lock 2025-06-26 19:47:06 +10:00
Billy
6afbf31750 Ruff formatting 2025-06-26 19:47:06 +10:00
Billy
3cd4306eec Update import path 2025-06-26 19:47:06 +10:00
Billy
827191d2fc Use definitions in config 2025-06-26 19:47:06 +10:00
Billy
aaa34f717d OMI files 2025-06-26 19:47:06 +10:00
Billy
fe83c2f81f Add OMI vendor files 2025-06-26 19:47:06 +10:00
Billy
17dead3309 Remove OMI from dependencies 2025-06-26 19:47:06 +10:00
Mary Hipp Rogers
979bd33dfb fix 1:1 ratio (#8127)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-06-25 19:39:21 -04:00
psychedelicious
5128f072a8 feat: add user_label to FieldIdentifier (#8126)
Co-authored-by: Mary Hipp Rogers <maryhipp@gmail.com>
2025-06-25 13:44:57 +00:00
Mary Hipp Rogers
2ad5b5cc2e Flux Kontext UI support (#8111)
* add support for flux-kontext models in nodes

* flux kontext in canvas

* add aspect ratio support

* lint

* restore aspect ratio logic

* more linting

* typegen

* fix typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-06-25 09:39:57 -04:00
jazzhaiku
24d8a96071 Omi (#8120)
## Summary

Support for
[OMI](https://github.com/Open-Model-Initiative/OMI-Model-Standards/tree/main)
LoRAs that use Flux and SDXL as the base model. Automated tests for
config classification. Manually tested (visual inspection) for LoRA
loading and execution.

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-06-24 14:53:57 +10:00
Billy
f1e4665aa2 Revert 2025-06-24 08:53:39 +10:00
Billy
1cbfea3a21 Update uv lock 2025-06-24 08:45:57 +10:00
Billy
981e8e217d Regenerate uv lock 2025-06-24 07:42:44 +10:00
Billy
e7ca30f406 Updated schema 2025-06-24 07:38:51 +10:00
Billy
2832ca300f Formatting 2025-06-24 07:26:42 +10:00
Billy
de5f413440 Filter bundle_emb for all LoRAs 2025-06-24 07:12:11 +10:00
Billy
fbc14c61ea Remove bundle_emb filter 2025-06-24 06:53:33 +10:00
Kent Keirsey
77e029a49f Ignore bundled embeddings in conversion 2025-06-23 10:05:55 -04:00
Kent Keirsey
61b049ad35 Fix to config 2025-06-23 09:52:47 -04:00
Billy
b88f4a24d0 Frontend types 2025-06-23 14:01:41 +10:00
Billy
8c632f0d32 Remove files 2025-06-23 13:54:21 +10:00
Billy
150a876c73 Formatting 2025-06-23 13:52:19 +10:00
Billy
62c3b01e4f Merge branch 'main' into OMI 2025-06-23 13:52:07 +10:00
Billy
e1157f343b Support for Flux and SDXL 2025-06-23 13:51:16 +10:00
Kent Keirsey
6a78739076 Change save button to Invoke Blue 2025-06-20 15:07:40 +10:00
psychedelicious
0794eb43e7 fix(nodes): ensure each invocation overrides _original_model_fields with own field data 2025-06-20 15:03:55 +10:00
Billy
4ee54eac1d Another attempt 2025-06-20 14:10:06 +10:00
Billy
5851c46c81 Hard code source 2025-06-19 11:05:43 +10:00
Billy
a296559e79 Ignore 2025-06-19 11:02:18 +10:00
Billy
1fd83f5e68 Import 2025-06-19 11:01:50 +10:00
Billy
637487c573 Convert FROM OMI to diffusers 2025-06-19 11:00:27 +10:00
Billy
4e98e7d0a2 Typo: dot should be comma 2025-06-19 10:47:24 +10:00
Billy
12f65d800d Formatting 2025-06-19 09:40:58 +10:00
Billy
45d09f8f51 Use OMI conversion utils 2025-06-19 09:40:49 +10:00
Billy
2876c72fa9 Schema update 2025-06-18 10:54:01 +10:00
Billy
9b4fdb493e Loader 2025-06-18 10:53:54 +10:00
Billy
47e21d6e04 Formatting 2025-06-17 13:56:38 +10:00
Billy
84ab4a1c30 Convert from OMI to default LoRA state dict 2025-06-17 13:56:22 +10:00
Billy
85c4304efd Add OMI LoRA config 2025-06-17 13:34:03 +10:00
Billy
8f152f162b Add OMI to model format taxonomy 2025-06-17 13:33:40 +10:00
Billy
63b49f045a Add stripped models for testing OMI 2025-06-17 13:33:23 +10:00
Mary Hipp
291e0736d6 fix names of unpublishable nodes 2025-06-16 12:40:54 -04:00
psychedelicious
4bfa6439d4 chore(ui): typegen 2025-06-16 19:33:19 +10:00
psychedelicious
a8d7969a1d fix(app): config docstrings 2025-06-16 19:33:19 +10:00
Heathen711
46bfa24af3 ruff format 2025-06-16 19:33:19 +10:00
Heathen711
a8cb8e128d run "make frontend-typegen" 2025-06-16 19:33:19 +10:00
Heathen711
8cef0f5bf5 Update supported cuda slot input. 2025-06-16 19:33:19 +10:00
psychedelicious
911baeb58b chore(ui): bump version to v5.15.0 2025-06-16 19:18:25 +10:00
Kevin Turner
312960645b fix: move AI Toolkit to the bottom of the detection list
to avoid disrupting already-working LoRA
2025-06-16 19:08:11 +10:00
Kevin Turner
50cf285efb fix: group aitoolkit lora layers 2025-06-16 19:08:11 +10:00
Kevin Turner
a214f4fff5 fix: group aitoolkit lora layers 2025-06-16 19:08:11 +10:00
Kevin Turner
2981591c36 test: add some aitoolkit lora tests 2025-06-16 19:08:11 +10:00
Kevin Turner
b08f90c99f WIP!: …they weren't in diffusers format… 2025-06-16 19:08:11 +10:00
Kevin Turner
ab8c739cd8 fix(LoRA): add ai-toolkit to lora loader 2025-06-16 19:08:11 +10:00
Kevin Turner
5c5108c28a feat(LoRA): support AI Toolkit LoRA for FLUX [WIP] 2025-06-16 19:08:11 +10:00
j-brooke
3df7cfd605 Updated fracturedjsonjs to version 4.1.0 and included settings adjustments for more pleasing comma placement. 2025-06-14 14:59:43 +10:00
psychedelicious
1ff3d44dba fix(app): guard against possible race conditions during enqueue
In #7724 we made a number of perf optimisations related to enqueuing. One of these optimisations included moving the enqueue logic - including expensive prep work and db writes - to a separate thread.

At the same time manual DB locking was abandoned in favor of WAL mode.

Finally, we set `check_same_thread=False` to allow multiple threads to access the connection at a given time.

I think this may be the cause of #7950:
- We start an enqueue in a thread (running in bg)
- We dequeue
- Dequeue pulls a partially-written queue item from DB and we get the errors in the linked issue

To be honest, I don't understand enough about SQLite to confidently say that this kind of race condition is actually possible. But:
- The error started popping up around the time we made this change.
- I have reviewed the logic from enqueue to dequeue very carefully _many_ times over the past month or so, and I am confident that the error is only possible if we are getting unexpected `NULL` values from the DB.
- The DB schema includes `NOT NULL` constraints for the column that is apparently returning `NULL`.
- Therefore, without some kind of race condition or schema issue, the error should not be possible.
- The `enqueue_batch` call is the only place I can find where we have the possibility of a race condition due to async logic. Everywhere else, all DB interaction for the queue is synchronous, as far as I can tell.

This change retains the perf benefits by running the heavy enqueue prep logic in a separate thread, but moves back to the main thread for the DB write. It also uses an explicit transaction for the write.

Will just have to wait and see if this fixes the issue.
2025-06-13 23:51:47 +10:00
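A simplified sketch of the pattern described here, using sqlite3 directly rather than the app's queue service (table and function names are illustrative): the heavy prep stays off-thread, while the write happens on the calling thread inside an explicit transaction.

```
import sqlite3
from concurrent.futures import ThreadPoolExecutor


def prepare_queue_items(batch: list[str]) -> list[tuple[str]]:
    # Stand-in for the expensive graph-prep work done off-thread.
    return [(f"prepared:{item}",) for item in batch]


conn = sqlite3.connect("queue.db")
conn.execute("CREATE TABLE IF NOT EXISTS queue (payload TEXT NOT NULL)")

# Heavy prep runs in a worker thread...
with ThreadPoolExecutor(max_workers=1) as pool:
    rows = pool.submit(prepare_queue_items, ["a", "b", "c"]).result()

# ...but the write happens here, in one explicit transaction, so a
# concurrent dequeue never sees a half-written queue item.
with conn:  # commits on success, rolls back on exception
    conn.executemany("INSERT INTO queue (payload) VALUES (?)", rows)
```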
Emmanuel Ferdman
c80ad90f72 Migrate to modern logger interface
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-13 13:07:09 +10:00
psychedelicious
3b4d1b8786 perf(app): gc before every queue item
This reduces peak memory usage at a negligible cost. Queue items typically take on the order of seconds, making the time cost of a GC essentially free.

Not a great idea on a hotter code path though.
2025-06-11 12:56:16 +10:00
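As a rough illustration of the trade-off (helper names are hypothetical):

```
import gc


def process_queue_item(item) -> None:
    ...  # stand-in for executing a session (typically seconds of work)


def run_queue(items) -> None:
    for item in items:
        gc.collect()  # reclaim the previous session's garbage up front
        process_queue_item(item)
```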
psychedelicious
c66201c7e1 perf(app): skip TI logic when no TIs to apply 2025-06-11 12:56:16 +10:00
psychedelicious
35c7c59455 fix(app): reduce peak memory usage
We've long suspected there is a memory leak in Invoke, but that may not be true. What looks like a memory leak may in fact be the expected behaviour for our allocation patterns.

We observe ~20 to ~30 MB increase in memory usage per session executed. I did some prolonged tests, where I measured the process's RSS in bytes while doing 200 SDXL generations. I found that it eventually leveled off at around 100 generations, at which point memory usage had climbed by ~900MB from its starting point.

I used tracemalloc to diff the allocations of single session executions and found that we are allocating ~20MB or so per session in `ModelPatcher.apply_ti()`.

In `ModelPatcher.apply_ti()` we add tokens to the tokenizer when handling TIs. The added tokens should be scoped to only the current invocation, but there is no simple way to remove the tokens afterwards.

As a workaround for this, we clone the tokenizer, add the TI tokens to the clone, and use the clone when running compel. Afterwards, this cloned tokenizer is discarded.

The tokenizer uses ~20MB of memory, and it has referrers/referents to other compel stuff. This is what is causing the observed increases in memory per session!

We'd expect these objects to be GC'd, but Python doesn't do it immediately. After creating the cond tensors, we quickly move on to denoising, so the GC never gets a chance to free its existing memory arenas/blocks for reuse. Instead, Python needs to request more memory from the OS.

We can improve the situation by immediately calling `del` on the tokenizer clone and related objects. In fact, we already had some code in the compel nodes to `del` some of these objects, but not all.

Adding the `del`s vastly improves things. We hit peak RSS in half as many sessions (~50 or fewer), and it's now only ~100MB more than the starting value. There is still a gradual increase in memory usage until we level off.
2025-06-11 12:56:16 +10:00
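A rough sketch of the workaround, assuming a Hugging Face-style tokenizer; this is illustrative, not the actual `ModelPatcher.apply_ti()` code:

```
import copy


def encode_with_ti(tokenizer, ti_token_strings: list[str], prompt: str):
    # Clone the tokenizer so the TI tokens never leak into the
    # long-lived instance shared across invocations.
    scoped_tokenizer = copy.deepcopy(tokenizer)
    scoped_tokenizer.add_tokens(ti_token_strings)
    try:
        return scoped_tokenizer(prompt)
    finally:
        # Drop the ~20MB clone immediately rather than waiting for a
        # later garbage collection pass to reclaim it.
        del scoped_tokenizer
```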
psychedelicious
85f98ab3eb fix(app): error on upload + resize for unusual image modes 2025-06-11 11:18:08 +10:00
Mary Hipp
dac75685be disable publish and cancel buttons once it begins 2025-06-10 19:50:09 -04:00
psychedelicious
d7b5a8b298 fix: opencv dependency conflict (#8095)
* build: prevent `opencv-python` from being installed

Fixes this error: `AttributeError: module 'cv2.ximgproc' has no attribute 'thinning'`

`opencv-contrib-python` supersedes `opencv-python`, providing the same API + additional features. The two packages should not be installed at the same time to avoid conflicts and/or errors.

The `invisible-watermark` package requires `opencv-python`, but we require the contrib variant.

This change updates `pyproject.toml` to prevent `opencv-python` from ever being installed, using a `uv` feature called dependency overrides.

* feat(ui): data viewer supports disabling wrap

* feat(api): list _all_ pkgs in app deps endpoint

* chore(ui): typegen

* feat(ui): update about modal to display new full deps list

* chore: uv lock
2025-06-10 08:33:41 -04:00
Kent Keirsey
d3ecaa740f Add Precise Reference to Starter Models 2025-06-09 22:02:11 +10:00
dunkeroni
b5a6765a3d also search image creation date 2025-06-09 21:54:26 +10:00
psychedelicious
3704573ef8 chore: bump version to v5.14.0 2025-06-06 22:36:32 +10:00
Hiroto N
01fbf2ce4d translationBot(ui): update translation (Japanese)
Currently translated at 76.5% (1467 of 1917 strings)

Co-authored-by: Hiroto N <hironow365@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI
2025-06-06 20:56:13 +10:00
Riccardo Giovanetti
96e7003449 translationBot(ui): update translation (Italian)
Currently translated at 98.9% (1896 of 1917 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-06-06 20:56:13 +10:00
RyoKoba
80197b8856 translationBot(ui): update translation (Japanese)
Currently translated at 76.1% (1460 of 1917 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI
2025-06-06 20:52:36 +10:00
Hosted Weblate
0187bc671e translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-06-06 20:52:36 +10:00
psychedelicious
31584daabe feat(ui): display canvas spinner during compositing operations 2025-06-06 20:50:02 +10:00
psychedelicious
a6cb522fed feat(ui): add bboxUpdated callback to transformer, use it to fit layer to stage when creating new canvas from an image
When a layer is initialized, we do not yet know its bbox, so we cannot fit the stage view to the layer. We have to wait for the bbox calculation to finish. Previously, we had no way to wait until that bbox calculation was complete before taking an action.

For example, this means we could not fit the layers to the stage immediately after creating a new layer, because we don't know the dimensions of the layer yet.

This callback lets us do that. When creating a new canvas from an image, we now...
- Register a bbox update callback to fit the layers to stage
- Layer is created
- Canvas initializes the layer's entity adapter module (layer's width and height are set to zero at this point)
- Canvas calculates the bbox
- Bbox is updated (width and height are now correct)
- Callback is run, fitting layer to stage
2025-06-06 20:50:02 +10:00
psychedelicious
f70be1e415 feat(ui): animate stage fit operations (e.g. fit layers to stage) 2025-06-06 20:50:02 +10:00
psychedelicious
a2901f2b46 feat(ui): add method to stage to fit to union of bbox and layers
This ensures that _both_ bbox and layers are visible
2025-06-06 20:50:02 +10:00
psychedelicious
b61c66c3a9 feat(ui): add spinner indicator to canvas during rasterizing operations and while pending rect calculations 2025-06-06 20:50:02 +10:00
psychedelicious
c77f9ec202 feat(ui): add hook to get all entity adapters in array 2025-06-06 20:50:02 +10:00
psychedelicious
2c5c35647f fix(ui): new canvas from image places image in bbox correctly 2025-06-06 20:50:02 +10:00
dunkeroni
bf0fdbd10e Fix: inpaint model mask using wrong tensor name 2025-06-05 11:31:35 -04:00
psychedelicious
731d317a42 chore(ui): update whatsnew 2025-06-04 22:29:37 +10:00
psychedelicious
e81579f752 fix(mm): handle invoke syntax for HF repo ids when fetching HF model metadata
Closes #8074
2025-06-04 22:27:15 +10:00
Linos
9a10e98c0b translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1918 of 1918 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-06-04 17:03:06 +10:00
Riccardo Giovanetti
27fdc139b7 translationBot(ui): update translation (Italian)
Currently translated at 98.9% (1897 of 1918 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-06-04 17:03:06 +10:00
psychedelicious
0a00805afc chore: bump version to v5.13.0 2025-06-04 05:55:34 +10:00
psychedelicious
7b38143fbd chore: bump version to v5.13.0rc3 2025-05-30 21:44:21 +10:00
mickr777
4c5ad1b7d7 Ruff Fix 2025-05-30 19:03:43 +10:00
mickr777
d80cc962ad Delay Imports that require torch 2025-05-30 19:03:43 +10:00
RyoKoba
7ccabfa200 translationBot(ui): update translation (Japanese)
Currently translated at 68.0% (1304 of 1915 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI
2025-05-30 14:48:41 +10:00
Riccardo Giovanetti
936d59cc52 translationBot(ui): update translation (Italian)
Currently translated at 98.9% (1894 of 1915 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-05-30 14:48:41 +10:00
psychedelicious
fc16fb6099 chore: bump version to v5.13.0rc2 2025-05-30 14:16:33 +10:00
psychedelicious
c848cbc2e3 feat(app): move output annotation checking to run_app
Also change import order to ensure CLI args are handled correctly. Had to do this because importing `InvocationRegistry` before parsing args resulted in the `--root` CLI arg being ignored.
2025-05-30 14:10:13 +10:00
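A minimal sketch of the ordering constraint (the deferred import shown is a placeholder, not the real module path): parse CLI args before importing anything that reads configuration or builds registries at import time.

```
import argparse


def run_app() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default=None)
    args = parser.parse_args()

    # Apply args (e.g. args.root) to the app config *before* importing
    # anything that reads that config at import time. The heavy import
    # happens here, inside the function, rather than at module top level:
    #
    #     from some_heavy_module import SomeRegistry  # placeholder name
    ...


if __name__ == "__main__":
    run_app()
```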
psychedelicious
66fd0f0d8a feat(ui): warn on unregistered invocation output 2025-05-30 14:10:13 +10:00
psychedelicious
c266f39f06 chore(ui): typegen 2025-05-30 13:36:04 +10:00
psychedelicious
98a44fa4d7 fix(ui): conditional display of message 2025-05-30 13:36:04 +10:00
Mary Hipp
c1d230f961 add support to delete all uncategorized images 2025-05-30 13:36:04 +10:00
Kevin Turner
68108435ae feat(LoRA): allow LoRA layer patcher to continue past unknown layers 2025-05-30 13:29:02 +10:00
psychedelicious
e121bf1f62 feat(ui): persist sizes of all 4 prompt boxes 2025-05-30 12:36:06 +10:00
psychedelicious
4835c344b3 feat(ui): implement generalized textarea size tracking system 2025-05-30 12:36:06 +10:00
Mary Hipp
a589dec122 store positive prompt textarea height in redux so it persists across refresh 2025-05-30 12:36:06 +10:00
dunkeroni
bc67d5c841 add invert logic to grayscale mask composite 2025-05-30 11:19:37 +10:00
Mary Hipp
f3d5691c04 use onClickGoToModelManager for empty model picker 2025-05-29 11:13:55 -04:00
psychedelicious
b98abc2457 chore(ui): typegen 2025-05-29 13:49:07 +10:00
psychedelicious
7e527ccfb7 feat(api): add validation for max resize_to on upload endpoint 2025-05-29 13:49:07 +10:00
psychedelicious
0f0c911845 chore: uv lock 2025-05-29 13:49:07 +10:00
psychedelicious
e4818b967b tidy(api): remove benchmark logging 2025-05-29 13:49:07 +10:00
psychedelicious
ce3eede26f feat(nodes): revised heuristic_resize
better handling for smaller image sizes
2025-05-29 13:49:07 +10:00
psychedelicious
d98725c5e9 feat(nodes): use guo-hall thinning 2025-05-29 13:49:07 +10:00
psychedelicious
31a96d2945 feat(ui): use resize on upload functionality when creating new canvas from image 2025-05-29 13:49:07 +10:00
psychedelicious
845a321a43 feat(ui): support resize_to when uploading images 2025-05-29 13:49:07 +10:00
psychedelicious
87a44a28ef chore(ui): typegen 2025-05-29 13:49:07 +10:00
psychedelicious
d5b9c3ee5a feat(api): support resizing image on upload 2025-05-29 13:49:07 +10:00
psychedelicious
91db136cd1 feat(nodes): much faster heuristic resize utility
Add `heuristic_resize_fast`, which does the same thing as `heuristic_resize`, except it's about 20x faster.

This is achieved by using opencv for the binary edge handling instead of Python, and checking only 100k pixels to determine what kind of image we are working with.

Besides being much faster, it results in cleaner lines for resized binary canny edge maps, and results in fewer misidentified segmentation maps.

Tested against normal images, binary canny edge maps, grayscale HED edge maps, and segmentation maps.

Tested resizing up and down for each.

Besides the new utility function, I needed to swap the `opencv-python` dep for `opencv-contrib-python`, which includes `cv2.ximgproc.thinning`. This function accounts for a good chunk of the perf improvement.
2025-05-29 13:49:07 +10:00
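A rough sketch of the approach described above, for illustration only: the pixel sampling, classification rules, and interpolation choices here are assumptions, not the actual `heuristic_resize_fast` implementation, and `cv2.ximgproc` requires `opencv-contrib-python`.

```py
import cv2
import numpy as np


def heuristic_resize_sketch(image: np.ndarray, size: tuple[int, int]) -> np.ndarray:
    """Resize a control image, picking the strategy from a cheap sample of pixels."""
    # Sample roughly 100k pixels to classify the image instead of scanning all of it.
    sample = image.reshape(-1, image.shape[-1] if image.ndim == 3 else 1)
    if sample.shape[0] > 100_000:
        sample = sample[:: sample.shape[0] // 100_000]

    is_binary = np.unique(sample).size <= 2  # e.g. a canny edge map (assumes uint8 input)

    if is_binary:
        # Nearest-neighbour keeps edges hard, then Guo-Hall thinning restores ~1px lines.
        resized = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
        gray = resized if resized.ndim == 2 else cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
        return cv2.ximgproc.thinning(gray, thinningType=cv2.ximgproc.THINNING_GUOHALL)

    # Everything else (photos, HED maps, segmentation maps) gets a plain resize here;
    # the real utility applies different interpolation per image kind.
    return cv2.resize(image, size, interpolation=cv2.INTER_LANCZOS4)
```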
Jonathan
f351ad4b66 Update communityNodes.md
Added some of JPPhoto's nodes.
2025-05-28 07:26:44 +10:00
psychedelicious
fb6fb9abbd gh: update CODEOWNERS
Added myself to everything so we do not get into situations where we need to rely on vic or lincoln to approve
2025-05-27 22:37:44 +10:00
psychedelicious
675c990486 docs: add comments to classifiers stuff 2025-05-27 22:02:48 +10:00
psychedelicious
6ee5cde4bb ci: do not install project when checking classifiers 2025-05-27 22:02:48 +10:00
psychedelicious
c8077f9430 ci: check classifiers in python-checks workflow 2025-05-27 22:02:48 +10:00
psychedelicious
6aabe9959e chore: fix license classifier 2025-05-27 22:02:48 +10:00
psychedelicious
0b58d172d2 build: update build script to check classifiers 2025-05-27 22:02:48 +10:00
psychedelicious
d7c6e293d7 scripts: add script to check pypi classifiers 2025-05-27 22:02:48 +10:00
psychedelicious
c600bc867d chore: bump version to v5.13.0rc1 2025-05-27 13:30:34 +10:00
Riccardo Giovanetti
f4140dd772 translationBot(ui): update translation (Italian)
Currently translated at 98.9% (1890 of 1911 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.9% (1890 of 1911 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-05-27 13:18:06 +10:00
psychedelicious
a2d8261d40 feat(ui): canvas scroll scale snap 2025-05-27 13:10:57 +10:00
psychedelicious
bce88a8873 perf(ui): lazy mount scale slider popover 2025-05-27 13:10:57 +10:00
psychedelicious
b37e1a3ad6 feat(ui): do not round scale
Makes it a lot smoother, don't think it breaks anything...
2025-05-27 13:10:57 +10:00
psychedelicious
35a088e0a6 perf(ui): optimize <CanvasToolbarScale /> 2025-05-27 13:10:57 +10:00
psychedelicious
b936cab039 feat(ui): add computed for stage scale 2025-05-27 13:10:57 +10:00
psychedelicious
34e4093408 fix(ui): revert snapping logic, doesn't work w/ certain input devices 2025-05-27 13:10:57 +10:00
Kent Keirsey
d7f93c3cc0 uv update 2025-05-26 22:54:15 -04:00
Kent Keirsey
d4c4926caa Update Compel to 2.1.1 and apply Sentences Split logic 2025-05-26 22:54:15 -04:00
psychedelicious
558c7db055 chore(ui): knipignore InpaintMaskAddButtons 2025-05-27 07:28:47 +10:00
psychedelicious
2ece59b51b feat(ui): remove unnecessary type casts 2025-05-27 07:28:47 +10:00
psychedelicious
7dbe39957c feat(ui): bbox rect is always defined, no need for fallback logic 2025-05-27 07:28:47 +10:00
psychedelicious
6fa46d35a5 feat(ui): inpaint mask settings layout 2025-05-27 07:28:47 +10:00
psychedelicious
b2a2b38ea8 feat(ui): split inpaint mask setting selectors to avoid manual memoization 2025-05-27 07:28:47 +10:00
dunkeroni
12934da390 Use Optional instead of Nullable for mask settings 2025-05-27 07:28:47 +10:00
dunkeroni
231bc18188 remove buttons, change denoise limit format 2025-05-27 07:28:47 +10:00
dunkeroni
530cd180c5 chore:ruff 2025-05-27 07:28:47 +10:00
dunkeroni
2a92e7b920 Flux/CogView/SD3 compatible with gradient masks 2025-05-27 07:28:47 +10:00
dunkeroni
019e057e29 chore: typegen 2025-05-27 07:28:47 +10:00
dunkeroni
9aa26f883e chore: ruff 2025-05-27 07:28:47 +10:00
dunkeroni
3f727e24b1 change default noise level to 0.15 2025-05-27 07:28:47 +10:00
dunkeroni
9e90bf1b20 fix gradient mask broken with flux gen 2025-05-27 07:28:47 +10:00
dunkeroni
db3964797f clean up comments 2025-05-27 07:28:47 +10:00
dunkeroni
881efbda1b fix: inpaint breaks when scaled processing 2025-05-27 07:28:47 +10:00
dunkeroni
e9ce2ed5f2 inpaint mask sliders compatible with outpainting 2025-05-27 07:28:47 +10:00
dunkeroni
53ac9eafbf reuse inpaint image noise seed for caching 2025-05-27 07:28:47 +10:00
dunkeroni
9e095006a5 remove some AI detritus 2025-05-27 07:28:47 +10:00
dunkeroni
21b24c3ba6 change denoise limit default to 1.0 2025-05-27 07:28:47 +10:00
dunkeroni
139ecc10ce ruff 2025-05-27 07:28:47 +10:00
dunkeroni
78ea143b46 composite masks based on denoise level 2025-05-27 07:28:47 +10:00
dunkeroni
174249ec15 gradient mask node works on greyscale now 2025-05-27 07:28:47 +10:00
dunkeroni
2510ad7431 consolidate code 2025-05-27 07:28:47 +10:00
dunkeroni
ba5e855a60 Correctly composite grey values on white for masks 2025-05-27 07:28:47 +10:00
dunkeroni
23627cf18d compositing in frontend 2025-05-27 07:28:47 +10:00
dunkeroni
5e20c9a1ca mask noise slider option 2025-05-27 07:28:47 +10:00
Kent Keirsey
933cf5f276 update prettier 2025-05-25 23:53:16 -04:00
Kent Keirsey
41316de659 Update order 2025-05-25 23:53:16 -04:00
Kent Keirsey
041ccfd68e Enable 'pull into bounding box' from empty Control Layer 2025-05-25 23:53:16 -04:00
dunkeroni
ad24c203a4 preserve SDXL training values for bounding box 2025-05-25 08:15:37 -04:00
Kent Keirsey
3fd28ce600 Update scaling math to land on 100% consistently. 2025-05-25 07:59:27 -04:00
Mary Hipp
32df3bdf6e typegen 2025-05-22 14:09:10 -04:00
Mary Hipp
ba69e89e8c typegen 2025-05-22 14:09:10 -04:00
Mary Hipp
a8e0c48ddc add new method types to metadata 2025-05-22 14:09:10 -04:00
Jonathan
66f6571086 Update manual installation for v5.12.0 2025-05-22 09:00:58 -04:00
psychedelicious
8a3848e7b6 chore(ui): update whats new copy 2025-05-22 14:25:02 +10:00
psychedelicious
3f8486b480 chore: bump version to v5.12.0 2025-05-22 14:25:02 +10:00
Hosted Weblate
b80be4f639 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-05-22 14:11:52 +10:00
Linos
adb3a849b9 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1910 of 1910 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-05-22 14:11:52 +10:00
Riccardo Giovanetti
798499fda6 translationBot(ui): update translation (Italian)
Currently translated at 98.9% (1889 of 1910 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.9% (1889 of 1910 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-05-22 14:11:52 +10:00
psychedelicious
02fc5a165c chore(ui): typegen 2025-05-22 13:50:15 +10:00
psychedelicious
b1b8edecfb fix(ui): minor ts issue 2025-05-22 13:50:15 +10:00
Mary Hipp
3cd8d48809 lint 2025-05-22 13:50:15 +10:00
Mary Hipp
f4672ad8c1 more cleanup 2025-05-22 13:50:15 +10:00
Mary Hipp
5a86490845 cleanup and refactor into hooks 2025-05-22 13:50:15 +10:00
Mary Hipp
27dc843046 Imagen4 working in UI 2025-05-22 13:50:15 +10:00
Mary Hipp
2f35d74902 backend updates 2025-05-22 13:50:15 +10:00
Kevin Turner
8bd52ed744 fix: improve gguf performance with torch.compile
pytorch 2.7 does not implement `set.__contains__`, so make this a list instead.

See https://github.com/pytorch/pytorch/issues/145761
2025-05-22 13:42:09 +10:00
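A tiny illustration of the workaround (the names below are hypothetical, not the actual GGUF code):

```py
# PyTorch 2.7's Dynamo cannot trace `set.__contains__` (pytorch/pytorch#145761),
# so membership containers reachable from torch.compile-d code are plain lists.
TORCH_COMPATIBLE_QUANT_TYPES = ["Q4_K", "Q5_K", "Q8_0"]  # previously a set


def is_supported(quant_type: str) -> bool:
    # `in` against a list (or tuple) traces cleanly under torch.compile
    return quant_type in TORCH_COMPATIBLE_QUANT_TYPES
```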
psychedelicious
f3e2a3c384 gh: update CODEOWNERS
- Remove brandon
- Consolidate two entries for `invokeai/backend`
2025-05-22 13:37:24 +10:00
psychedelicious
ecc6e8a532 fix(nodes): transformers bug with SAM
Upstream bug in `transformers` breaks use of `AutoModelForMaskGeneration` class to load SAM models

Simple fix - directly load the model with `SamModel` class instead.

See upstream issue https://github.com/huggingface/transformers/issues/38228
2025-05-22 11:32:37 +10:00
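A minimal sketch of the workaround, assuming a standard Hugging Face SAM checkpoint (the repo id below is just an example):

```py
from transformers import SamModel, SamProcessor

# Load the concrete SamModel class directly instead of going through
# AutoModelForMaskGeneration, which hits the upstream transformers bug
# (huggingface/transformers#38228).
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
```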
Mary Hipp
9170576a38 make logic more straightforward 2025-05-21 10:52:04 -04:00
Mary Hipp
f26baa0341 use hook instead 2025-05-21 10:52:04 -04:00
psychedelicious
99dad953a4 chore: bump version to v5.12.0rc2 2025-05-20 14:50:03 +10:00
jazzhaiku
c39bcdffd3 Re-enable classification API as fallback (#8007)
## Summary

- Fallback to new classification API if legacy probe fails
- Method to read model metadata
- Created `StrippedModelOnDisk` class for testing
- Test to verify only a single config `matches` with a model

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-05-20 11:25:38 +10:00
Billy
32f2223237 Warning comment 2025-05-20 11:19:59 +10:00
Billy
6176941853 Warning comment 2025-05-20 11:19:59 +10:00
Billy
af41dc83f7 Make ruff happy 2025-05-20 11:19:59 +10:00
Billy
a17e771eba Re-enable classification API as fallback 2025-05-20 11:19:59 +10:00
psychedelicious
19ecdb196e chore: ruff 2025-05-20 10:47:02 +10:00
psychedelicious
15880e6ea7 fix(ui): invocation parsing for optional enum fields
For example:
```py
my_field: Literal["foo", "bar"] | None = InputField(default=None)
```

Previously, this would cause a field parsing error and prevent the app from loading.

Two fixes:
- This type annotation and resultant schema are now parsed correctly
- Error handling added to template building logic to prevent the hang at startup when an error does occur
2025-05-20 10:47:02 +10:00
psychedelicious
53ffa98662 chore(ui): typegen 2025-05-20 10:47:02 +10:00
psychedelicious
021a334240 fix(nodes): fix spots where default of None was provided for non-optional fields 2025-05-20 10:47:02 +10:00
psychedelicious
cfed293d48 fix(nodes): do not make invocation field defaults None when they are not provided 2025-05-20 10:47:02 +10:00
Mary Hipp
d36bc185c8 only use client-side uploads if there is more than one image, so that metadata is retained for single uploads 2025-05-20 08:03:00 +10:00
psychedelicious
7878203b03 chore(ui): update whats new copy 2025-05-19 23:28:40 +10:00
psychedelicious
3352220d39 chore: bump version to v5.12.0rc1 2025-05-19 23:28:40 +10:00
Riccardo Giovanetti
bcfb1e7e52 translationBot(ui): update translation (Italian)
Currently translated at 98.7% (1887 of 1910 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-05-19 23:23:07 +10:00
psychedelicious
e84b3c142c chore(ui): typegen 2025-05-19 13:50:04 +10:00
Kent Keirsey
22f637b647 ruff ruff 2025-05-19 13:50:04 +10:00
Kent Keirsey
5d192ab6e5 Fix SD precise in patcher. 2025-05-19 13:50:04 +10:00
Kent Keirsey
9273d1629e UX Copy Clean-up 2025-05-19 13:50:04 +10:00
Kent Keirsey
27a12f080b missing translation values 2025-05-19 13:50:04 +10:00
Kent Keirsey
3bfb497764 ruff fixes 2025-05-19 13:50:04 +10:00
Kent Keirsey
b849c7d382 ruff fix 2025-05-19 13:50:04 +10:00
Kent Keirsey
8d4120583d update schema pt 2 2025-05-19 13:50:04 +10:00
Kent Keirsey
402cdc7eda update schema 2025-05-19 13:50:04 +10:00
Kent Keirsey
b02ea1a898 Expanded styles & updated UI 2025-05-19 13:50:04 +10:00
Kent Keirsey
d709040f4b Matt3o base changes 2025-05-19 13:50:04 +10:00
psychedelicious
8a7a498da3 chore: update uv lock 2025-05-19 12:29:51 +10:00
psychedelicious
699736486b chore: bump torch to 2.7.0
- Update `pyproject.toml`
- Update `pins.json` so launcher installs latest CUDA 12.8 & ROCm 6.3
2025-05-19 12:29:51 +10:00
psychedelicious
37e790ae19 fix(app): address pydantic deprecation warning for accessing BaseModel.model_fields 2025-05-19 12:22:59 +10:00
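For context, the deprecation in question is pydantic 2.11 warning when `model_fields` is read from an instance rather than the class; a minimal before/after sketch (the model below is hypothetical):

```py
from pydantic import BaseModel


class AppConfig(BaseModel):
    root: str = "~/invokeai"


cfg = AppConfig()

# Deprecated since pydantic 2.11: instance access emits a DeprecationWarning.
# fields = cfg.model_fields

# Preferred: read model_fields from the class.
fields = type(cfg).model_fields
print(list(fields))  # ['root']
```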
David Burnett
6c0bd7d150 fix import ordering, remove code I reverted that the resync added back 2025-05-19 11:16:23 +10:00
David Burnett
99e154d773 fix picky ruff issue 2025-05-19 11:16:23 +10:00
David Burnett
e4e43ae126 fix missing bracket 2025-05-19 11:16:23 +10:00
David Burnett
a07fac6180 raise expected exception when attempting to change dtype 2025-05-19 11:16:23 +10:00
David Burnett
93d4b00082 Add to overload for GGMLTensor, so calling to on the model moves the quantized data as well 2025-05-19 11:16:23 +10:00
David Burnett
8abcc99ced add check for state_dict, required to load TI's 2025-05-19 11:16:23 +10:00
David Burnett
73ab4b8895 fix offload device 2025-05-19 11:16:23 +10:00
David Burnett
86719f2065 revert to overload due to failing tests, use Torch futures instead 2025-05-19 11:16:23 +10:00
David Burnett
5271fc1cac fix picky ruff issue 2025-05-19 11:16:23 +10:00
David Burnett
96ff7d9093 fix missing bracket 2025-05-19 11:16:23 +10:00
David Burnett
6f73d9e9c6 raise expected exception when attempting to change dtype 2025-05-19 11:16:23 +10:00
David Burnett
29b406a84b Add to overload for GGMLTensor, so calling to on the model moves the quantized data as well 2025-05-19 11:16:23 +10:00
psychedelicious
2b1e4b88d3 tests: add new service to mocks 2025-05-19 10:29:07 +10:00
psychedelicious
0f0085a776 chore(ui): typegen 2025-05-19 10:29:07 +10:00
psychedelicious
ea28ed8261 chore: ruff 2025-05-19 10:29:07 +10:00
Lucian Hardy
c0e6327d3a chore(ui): Refactor RelatedModels.tsx
Major cleanup of RelatedModels.tsx for improved readability, structure, and maintainability.
- Dried out repetitive logic
- Consolidated model type sorting into reusable helpers
- Added disallowed model type relationships to prevent broken connections (e.g. VAE ↔ LoRA)
  - Aware this introduces a new constraint; open to feedback (see PR comment)
- Some naming and types may still need refinement; happy to revisit
2025-05-19 10:29:07 +10:00
Lucian Hardy
459491e402 chore(backend): Removed unused model_relationship methods
removed unused AnyModelConfig related methods,
removed unused get_related_model_key_count method.
2025-05-19 10:29:07 +10:00
Lucian Hardy
a4cddfa47d feat(ui): model relationship management
Adds full support for managing model-to-model relationships in the UI and backend.

Introduces RelatedModels subpanel for linking and unlinking models in model management.
 - Adds REST API routes for adding, removing, and retrieving model relationships.
 - New database migration: creates model_relationships table for bidirectional links.
 - New service layer (model_relationships) for relationship management.
 - Updated frontend: Related models float to top of LoRA/Main grouped model comboboxes for quick access.
     - Added 'Show Only Related' toggle badge to MainModelPicker filter bar

**Amended commit to remove changes to ParamMainModelSelect.tsx and MainModelPicker.tsx to avoid conflict with upstream deletion/ rewrite**
2025-05-19 10:29:07 +10:00
jazzhaiku
9a822bcfe8 Jazzhaiku/stats (#8006)
## Summary

- Modify stats reset to be on a per session basis, rather than a "full
reset", to allow for parallel session execution
- Add "aider" to gitignore

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-05-16 07:51:23 +10:00
psychedelicious
5f12b9185f feat(mm): add cache_snapshot to model cache clear callback 2025-05-15 16:06:47 +10:00
psychedelicious
d958d2e5a0 feat(mm): iterate on cache callbacks API 2025-05-15 14:37:22 +10:00
psychedelicious
823ca214e6 feat(mm): iterate on cache callbacks API 2025-05-15 13:28:51 +10:00
psychedelicious
a33da450fd feat(mm): support cache callbacks 2025-05-15 11:23:58 +10:00
Billy
8b5f4d190c Restore Schema 2025-05-15 10:38:01 +10:00
Billy
f1f3b7965a Schema 2025-05-15 10:26:45 +10:00
Billy
987be3507c Merge branch 'main' into jazzhaiku/stats 2025-05-15 10:22:56 +10:00
Billy
1f4090fe0e Reset invocation stats on per session basis 2025-05-15 10:19:05 +10:00
Billy
029e2d2c46 Add aider to gitignore 2025-05-15 10:18:42 +10:00
Riku
7722f479e8 translationBot(ui): update translation (German)
Currently translated at 64.9% (1236 of 1902 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Linos
3ad4072183 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1904 of 1904 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1902 of 1902 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Hosted Weblate
6dfb9a1906 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
RyoKoba
ad2924350d translationBot(ui): update translation (Japanese)
Currently translated at 67.1% (1279 of 1904 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 64.9% (1231 of 1895 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 60.2% (1141 of 1895 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 56.7% (1075 of 1895 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Linos
3bf51ee0c2 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1896 of 1896 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1895 of 1895 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1886 of 1886 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Hosted Weblate
fce5051dcc translationBot(ui): update translation files
Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Riccardo Giovanetti
446d8818b9 translationBot(ui): update translation (Italian)
Currently translated at 98.8% (1883 of 1904 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1882 of 1903 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1881 of 1902 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1878 of 1899 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1874 of 1895 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1873 of 1895 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1864 of 1886 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
psychedelicious
1566e29c19 feat(nodes): tidy some type annotations in baseinvocation 2025-05-14 06:55:15 +10:00
psychedelicious
6a2e35f2c4 feat(nodes): store original field annotation & FieldInfo in invocations 2025-05-14 06:55:15 +10:00
psychedelicious
b6d58774f4 feat(nodes): improved error messages for invalid defaults 2025-05-14 06:55:15 +10:00
psychedelicious
758f94d3c6 chore(ui): typegen 2025-05-14 06:55:15 +10:00
psychedelicious
9df0871754 fix(nodes): do not provide invalid defaults for batch nodes 2025-05-14 06:55:15 +10:00
psychedelicious
3011150a3a feat(nodes): validate default values for all fields
This prevents issues where the node is defined with an invalid default value, which would guarantee an error during a ser/de roundtrip.

- Upstream issue requesting this functionality be built-in to pydantic: https://github.com/pydantic/pydantic/issues/8722
- Upstream PR that implements the functionality: https://github.com/pydantic/pydantic-core/pull/1593
2025-05-14 06:55:15 +10:00
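A rough sketch of the kind of guard this adds, with hypothetical field names (the real check lives in the invocation base classes):

```py
from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class BadNode(BaseModel):
    # Invalid default: pydantic does not validate defaults unless told to,
    # so this class definition succeeds silently.
    steps: int = Field(default="twenty")  # type: ignore[assignment]


def validate_field_defaults(model: type[BaseModel]) -> None:
    """Fail fast if any declared default does not satisfy its field's annotation."""
    for name, field in model.model_fields.items():
        if field.is_required():
            continue  # no default to check
        try:
            TypeAdapter(field.annotation).validate_python(field.default)
        except ValidationError as err:
            raise TypeError(f"{model.__name__}.{name} has an invalid default") from err


validate_field_defaults(BadNode)  # raises TypeError: BadNode.steps has an invalid default
```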
psychedelicious
05aa1fce71 chore(ui): typegen 2025-05-14 06:55:15 +10:00
psychedelicious
df81f3274a feat(nodes): improved pydantic type annotation massaging
When we do our field type overrides to allow invocations to be instantiated without all required fields, we previously did not modify the field's annotation, but we did set its default value to `None`.

This results in an error when doing a ser/de round trip. Here's what we end up doing:

```py
from pydantic import BaseModel, Field

class MyModel(BaseModel):
    foo: str = Field(default=None)
```

And here is a simple round-trip, which should not error but which does:

```py
MyModel(**MyModel().model_dump())
# ValidationError: 1 validation error for MyModel
# foo
#   Input should be a valid string [type=string_type, input_value=None, input_type=NoneType]
#     For further information visit https://errors.pydantic.dev/2.11/v/string_type
```

To fix this, we now check every incoming field and update its annotation to match its default value. In other words, when we override the default field value to `None`, we make its type annotation `<original type> | None`.

This prevents the error during deserialization.

This slightly alters the schema for all invocations and outputs - the values of all fields without default values are now typed as `<original type> | None`, reflecting the overrides.

This means the autogenerated types for fields have also changed for fields without defaults:

```ts
// Old
image?: components["schemas"]["ImageField"];

// New
image?: components["schemas"]["ImageField"] | null;
```

This does not break anything on the frontend.
2025-05-14 06:55:15 +10:00
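And the corresponding fix in miniature: when the default is overridden to `None`, the annotation is widened to match, so the round trip above validates cleanly (a sketch only, not the actual override code):

```py
from pydantic import BaseModel, Field


class MyModel(BaseModel):
    # Annotation widened to `str | None` to agree with the overridden default.
    foo: str | None = Field(default=None)


MyModel(**MyModel().model_dump())  # round-trips without a ValidationError
```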
psychedelicious
143487a492 chore: bump version to v5.11.0 2025-05-13 14:04:45 +10:00
psychedelicious
203fa04295 feat(nodes): support bottleneck flag for nodes 2025-05-13 11:56:40 +10:00
Mary Hipp Rogers
954fce3c67 feat(ui): custom error toast support (#8001)
* support for custom error toast components, starting with usage limit

* add support for all usage limits

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-05-08 15:53:10 -04:00
Mary Hipp
821889148a easier way to override Whats New 2025-05-07 15:40:21 -04:00
Mary Hipp
4c248d8c2c refetch queue list on mount 2025-05-07 15:37:55 -04:00
Mary Hipp
deb75805d4 use the max for iterations passed in 2025-05-06 18:26:40 -04:00
Mary Hipp Rogers
93110654da Change feature to disable apiModels to chatGPT4oModels only (#7996)
* display credit column in queue list if shouldShowCredits is true

* change apiModels feature to chatGPT4oModels feature

* empty

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-05-06 14:37:03 -04:00
psychedelicious
ff0c48d532 chore(ui): prettier 2025-05-06 09:07:52 -04:00
psychedelicious
de18073814 feat(ui): support imagen3/chatgpt-4o models in canvas 2025-05-06 09:07:52 -04:00
psychedelicious
0708af9545 feat(ui): support imagen3/chatgpt-4o models in workflow editor 2025-05-06 09:07:52 -04:00
psychedelicious
1e85184c62 feat(nodes): add imagen3/chatgpt-4o field types 2025-05-06 09:07:52 -04:00
psychedelicious
11d3b8d944 feat(ui): add usage info to model picker 2025-05-06 09:07:52 -04:00
psychedelicious
bffd4afb96 chore(ui): typegen 2025-05-06 09:07:52 -04:00
psychedelicious
518a896521 feat(mm): add usage_info to model config 2025-05-06 09:07:52 -04:00
psychedelicious
2647ff141a feat(ui): add basic metadata to imagen3/chatgpt-4o graphs 2025-05-06 09:07:52 -04:00
Mary Hipp Rogers
ba0bac2aa5 add credits to queue item status changed (#7993)
* display credit column in queue list if shouldShowCredits is true

* add credits when queue item status changes

* chore(ui): typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-05-06 08:54:44 -04:00
psychedelicious
862e2a3e49 chore(ui): typegen 2025-05-05 16:09:13 -04:00
Mary Hipp
d22fd32b05 typegen 2025-05-05 16:09:13 -04:00
Mary Hipp
391e5b7f8c update schema 2025-05-05 16:09:13 -04:00
Mary Hipp
c9d2a5f59a display credit column in queue list if shouldShowCredits is true 2025-05-05 16:09:13 -04:00
Kent Keirsey
1f63b60021 Implementing support for Non-Standard LoRA Format (#7985)
* integrate loRA

* idk anymore tbh

* enable fused matrix for quantized models

* integrate loRA

* idk anymore tbh

* enable fused matrix for quantized models

* ruff fix

---------

Co-authored-by: Sam <bhaskarmdutt@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-05-05 09:40:38 -04:00
psychedelicious
a499b9f54e chore: bump version to v5.11.0rc2 2025-05-05 23:32:27 +10:00
psychedelicious
104505ea02 chore(ui): lint 2025-05-05 23:25:29 +10:00
psychedelicious
ee4002607c feat(ui): add UI to reset hf token 2025-05-05 23:25:29 +10:00
psychedelicious
fd20582cdd chore(ui): typegen 2025-05-05 23:25:29 +10:00
psychedelicious
43b0d07517 feat(api): add route to reset hf token 2025-05-05 23:25:29 +10:00
blessedcoolant
f83592a052 fix: deprecation warning in get_iso_timestemp 2025-05-05 11:45:30 +10:00
Mary Hipp
b3ee906749 add prompt validation to imagen3 graph 2025-05-01 13:02:13 -04:00
psychedelicious
5d69e9068a feat(ui): add ability to globally disable hotkeys
This will both hide the hotkey from the hotkey modal and override any other enabled status it has.
2025-05-01 10:50:34 -04:00
psychedelicious
a79136b058 fix(ui): always add selectModelsTab hotkey data to prevent unhandled exception while registering the hotkey handler 2025-05-01 10:50:34 -04:00
psychedelicious
944af4d4a9 feat(ui): show unsupported gen mode toasts as warnings instead of errors 2025-05-01 23:25:01 +10:00
psychedelicious
5e001be73a tidy(ui): remove excessive nav to mm buttons 2025-05-01 23:22:19 +10:00
psychedelicious
576a644b3a tidy(ui): modelpicker component 2025-05-01 23:22:19 +10:00
psychedelicious
703557c8a6 feat(ui): cleanup 2025-05-01 23:22:19 +10:00
psychedelicious
d59a53b3f9 feat(ui): simplify picker types 2025-05-01 23:22:19 +10:00
psychedelicious
7b8f78c2d9 fix(ui): focus bug w/ popover 2025-05-01 23:22:19 +10:00
psychedelicious
31ab9be79a feat(ui): iterate on picker 2025-05-01 23:22:19 +10:00
psychedelicious
5011fab85d fix(ui): restore FLUX Dev info popover to main model picker 2025-05-01 10:59:51 +10:00
psychedelicious
92bdb9fdcc chore(ui): remove unused exports 2025-05-01 10:59:51 +10:00
Mary Hipp
548e766c0b feat(ui): ability to disable generating with API models 2025-05-01 10:59:51 +10:00
Mary Hipp
ff897f74a1 send the list of reference images reversed to chatGPT so it matches displayed order 2025-04-30 15:56:38 -04:00
psychedelicious
3d29c996ed feat(ui): support img2img for chatgpt 4o w/ ref images 2025-04-30 13:39:05 +10:00
psychedelicious
42d57d1225 fix(ui): ref image layout 2025-04-30 13:39:05 +10:00
psychedelicious
193fa9395a fix(ui): match ref image model to main model when creating global ref image 2025-04-30 13:39:05 +10:00
psychedelicious
56cd839d5b feat(ui): support for ref images for chatgpt on canvas 2025-04-30 13:39:05 +10:00
ubansi
7b446ee40d docs: fix Contribute node import error
When I followed the Contribute Node documentation, I encountered an import error.
This commit fixes the error, which will help reduce debugging time for all future contributors.
2025-04-29 21:03:00 -04:00
Mary Hipp Rogers
17027c4070 Maryhipp/chatgpt UI (#7969)
* add GPTimage1 as allowed base model

* fix for non-disabled inpaint layers

* lots of boilerplate for adding gpt-image base model and disabling things along with imagen

* handle gpt-image dimensions

* build graph for gpt-image

* lint

* feat(ui): make chatgpt model naming consistent

* feat(ui): graph builder naming

* feat(ui): disable img2img for imagen3

* feat(ui): more naming

* feat(ui): support presigned url prefetch

* feat(ui): disable neg prompt for chatgpt

* docs(ui): update docstring

* feat(ui): fix graph building issues for chatgpt

* fix(ui): node ids for chatgpt/imagen

* chore(ui): typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-04-29 09:38:03 -04:00
psychedelicious
13d44f47ce chore(ui): prettier 2025-04-29 09:12:49 +10:00
psychedelicious
550fbdeb1c fix(ui): more types fixes 2025-04-29 09:12:49 +10:00
psychedelicious
a01cd7c497 fix(ui): add chatgpt-4o to zod schemas that need to match autogenerated types 2025-04-29 09:12:49 +10:00
Mary Hipp
c54afd600c typegen 2025-04-29 09:12:49 +10:00
Mary Hipp
4f911a0ea8 typegen 2025-04-29 09:12:49 +10:00
Mary Hipp
fb91f48722 change base model for chatGPT 4o 2025-04-29 09:12:49 +10:00
psychedelicious
69db60a614 fix(ui): toast typo 2025-04-29 06:56:36 +10:00
Mary Hipp
c6d7f951aa typegen 2025-04-28 15:39:11 -04:00
Mary Hipp
04c005284c add gpt-image to possible base model types 2025-04-28 15:39:11 -04:00
psychedelicious
2d7f9697bf chore(ui): lint 2025-04-28 13:31:26 -04:00
psychedelicious
ae530492a2 chore(ui): typegen 2025-04-28 13:31:26 -04:00
psychedelicious
87ed1e3b6d feat(ui): do not allow imagen3 nodes in published workflows 2025-04-28 13:31:26 -04:00
psychedelicious
cc54466db9 fix(nodes): default value for UIConfigBase.tags 2025-04-28 13:31:26 -04:00
psychedelicious
cbdafe7e38 feat(nodes): allow node clobbering 2025-04-28 13:31:26 -04:00
psychedelicious
112cb76174 fix: random seed for edit mode imagen 2025-04-28 13:31:26 -04:00
psychedelicious
e56d41ab99 feat: rip out enhance prompt as toggleable option, imagen always randomizes seed 2025-04-28 13:31:26 -04:00
psychedelicious
273dfd86ab fix(ui): upscale builder 2025-04-28 13:31:26 -04:00
psychedelicious
871271fde5 feat(ui): rough out imagen3 support for canvas 2025-04-28 13:31:26 -04:00
psychedelicious
14944872c4 feat(mm): add model taxonomy for API models & Imagen3 as base model type 2025-04-28 13:31:26 -04:00
psychedelicious
07bcf3c446 feat(ui): port bbox select to native select 2025-04-28 13:31:26 -04:00
psychedelicious
8ed5585285 feat(nodes): move output metadata to BaseInvocationOutput 2025-04-28 09:19:43 -04:00
psychedelicious
5ce226a467 chore(ui): typegen 2025-04-28 09:19:43 -04:00
Mary Hipp
c64f20a72b remove output_metdata from schema 2025-04-28 09:19:43 -04:00
Mary Hipp
0c9c10a03a update schema 2025-04-28 09:19:43 -04:00
Mary Hipp
4a0df6b865 add optional output_metadata to baseinvocation 2025-04-28 09:19:43 -04:00
psychedelicious
ba165572bf chore: bump version to v5.11.0rc1 2025-04-28 10:10:50 +10:00
psychedelicious
c3d6a10603 fix(ui): handle minor breaking typing change from serialize-error 2025-04-28 09:53:08 +10:00
psychedelicious
4efc86299d fix(ui): type error in SettingsUpsellMenuItem 2025-04-28 09:53:08 +10:00
psychedelicious
e8c7cf63fd fix(ui): type error in canvas worker 2025-04-28 09:53:08 +10:00
psychedelicious
698b034190 chore(ui): bump deps 2025-04-28 09:53:08 +10:00
psychedelicious
3988128c40 feat(ui): add _all_ image outputs to gallery (including collections) 2025-04-28 09:49:04 +10:00
psychedelicious
c768f47365 fix(ui): dnd autoscroll in scrollable containers 2025-04-28 09:46:38 +10:00
psychedelicious
19a63abc54 fix(ui): hide file size on model picker when it is zero 2025-04-23 17:45:09 +10:00
psychedelicious
75ec36bf9a chore(ui): lint 2025-04-23 17:45:09 +10:00
psychedelicious
d802f8e7fb feat(ui): disable search when no options 2025-04-23 17:45:09 +10:00
psychedelicious
6873e0308d feat(ui): custom fallback for model picker when no models installed 2025-04-23 17:45:09 +10:00
psychedelicious
66eb73088e feat(ui): rename user-provided extra ctx for picker from ctx to extra to be less confusing 2025-04-23 17:45:09 +10:00
psychedelicious
ed81a13eb4 docs(ui): add some comments for picker 2025-04-23 17:45:09 +10:00
psychedelicious
fbc1aae52d feat(ui): more flexible fallbacks for model picker 2025-04-23 17:45:09 +10:00
psychedelicious
ba42c3e63f feat(ui): tooltip for compact/full model picker view 2025-04-23 17:45:09 +10:00
psychedelicious
b24e820aa0 fix(ui): flash of "select a model" when changing model 2025-04-23 17:45:09 +10:00
psychedelicious
e8f6b3b77a feat(ui): split out mainmodelpicker component 2025-04-23 17:45:09 +10:00
psychedelicious
8f13518c97 feat(ui): add clear search button to model combobox 2025-04-23 17:45:09 +10:00
psychedelicious
6afbc12074 feat(ui): when no model bases selected, show all models 2025-04-23 17:45:09 +10:00
psychedelicious
6b0a56ceb9 chore(ui): lint 2025-04-23 17:45:09 +10:00
psychedelicious
ca92497e52 feat(ui): remove description from model picker for now 2025-04-23 17:45:09 +10:00
psychedelicious
97d45ceaf2 feat(ui): model picker filter buttons 2025-04-23 17:45:09 +10:00
psychedelicious
aeb3841a6f feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
c14d33d3c1 tweak(ui): remove bg on ModelImage fallback 2025-04-23 17:45:09 +10:00
psychedelicious
676e59e072 chore(ui): bump react-resizable-panels to latest
This resolves a bug where SVG elements were ignored when checking whether the cursor is over a resize handle
2025-04-23 17:45:09 +10:00
psychedelicious
e7dcb6a03f feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
fb95b7cc2b feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
015dc3ac0d feat(ui): wip model picker 2025-04-23 17:45:09 +10:00
psychedelicious
9d8a71b362 feat(ui): genericizing picker 2025-04-23 17:45:09 +10:00
psychedelicious
2eb212f393 feat(ui): onSelectId -> onSelectById 2025-04-23 17:45:09 +10:00
psychedelicious
34b268c15c feat(ui): use context for stable picker state 2025-04-23 17:45:09 +10:00
psychedelicious
9a203a64dc feat(ui): render picker in portal 2025-04-23 17:45:09 +10:00
psychedelicious
d80004e056 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
de32ed23a7 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
5aed2b315d feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
48db6cfc4f feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
aa7c5c281a feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
87aeb7f889 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
3b3d6e413a feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
b6432f2de3 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
9d0a28ccae feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
c3bf0a3277 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
b516610c1e feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
677e717cd7 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
c52584e057 feat(ui): simplify ScrollableContent 2025-04-23 17:45:09 +10:00
psychedelicious
b6767441db feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
8745dbe67d feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
a565d9473e feat(ui): add useStateImperative 2025-04-23 17:45:09 +10:00
psychedelicious
4dbf07c3e0 feat(ui): iterate on model combobox (wip) 2025-04-23 17:45:09 +10:00
psychedelicious
f6eb4d9a6b feat(ui): toast on select for demo purposes 2025-04-23 17:45:09 +10:00
psychedelicious
5037967b82 feat(ui): just make the damn thing myself 2025-04-23 17:45:09 +10:00
psychedelicious
4930ba48ce feat(ui): just make the damn thing myself 2025-04-23 17:45:09 +10:00
psychedelicious
40d2092256 feat(ui): reworked model selection ui (WIP) 2025-04-23 17:45:09 +10:00
psychedelicious
d2e9237740 feat(ui): reworked model selection ui (WIP) 2025-04-23 17:45:09 +10:00
psychedelicious
b191b706c1 feat(ui): reworked model selection ui (WIP) 2025-04-23 17:45:09 +10:00
psychedelicious
4d0f760ec8 chore(ui): bump cmdk to latest 2025-04-23 17:45:09 +10:00
psychedelicious
65cda5365a feat(ui): remove go to mm button from node fields 2025-04-23 17:45:09 +10:00
psychedelicious
1f2d1d086f feat(ui): add <NavigateToModelManagerButton /> to model comboboxes everywhere 2025-04-23 17:45:09 +10:00
psychedelicious
418f3c3f19 feat(ui): abstract out workflow editor model combobox, ensure consistent ui for all model fields 2025-04-23 17:45:09 +10:00
psychedelicious
72173e284c fix(ui): useModelCombobox should use null for no value instead of undefined
This fixes an issue where the refiner combobox doesn't clear itself visually when clicking the little X icon to clear the selection.
2025-04-23 17:45:09 +10:00
psychedelicious
9cc13556aa feat(ui): accept callback to override navigate to model manager functionality
If provided, `<NavigateToModelManagerButton />` will render even if `disabledTabs` includes "models", and it will run the callback instead of switching tabs within the studio.

The button's tooltip is now just "Manage Models" and its icon is the same as the model manager tab's icon ([CUBE!](https://www.youtube.com/watch?v=4aGDCE6Nrz0)).
2025-04-23 17:45:09 +10:00
psychedelicious
298444f2bc chore: bump version to v5.10.1 2025-04-19 00:05:02 +10:00
psychedelicious
deb1984289 fix(mm): disable new model probe API
There is a subtle change in behaviour with the new model probe API.

Previously, checks for model types were done in a specific order. For example, we did all main model checks before LoRA checks.

With the new API, the order of checks has changed. Check ordering is as follows:
- New API checks are run first, then legacy API checks.
- New API checks are categorized by their speed. When we run new API checks, we sort them from fastest to slowest, and run them in that order. This is a performance optimization.

Currently, LoRA and LLaVA models are the only model types with the new API. Checks for them are thus run first.

LoRA checks involve checking the state dict for presence of keys with specific prefixes. We expect these keys to only exist in LoRAs.

It turns out that main models may have some of these keys.

For example, this model has keys that match the LoRA prefix `lora_te_`: https://civitai.com/models/134442/helloyoung25d

Under the old probe, we'd do the main model checks first and correctly identify this as a main model. But with the new setup, we do the LoRA check first, and those pass. So we import this model as a LoRA.

Thankfully, the old probe still exists. For now, the new probe is fully disabled. It was only called in one spot.

I've also added the example affected model as a test case for the model probe. Right now, this causes the test to fail, and I've marked the test as xfail. CI will pass.

Once we enable the new API again, the xfail will pass, and CI will fail, and we'll be reminded to update the test.
2025-04-18 22:44:10 +10:00
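A hedged sketch of the ambiguity described above (key names are illustrative; only the `lora_te_` prefix comes from the actual report):

```py
# A prefix check that is sufficient for most LoRAs...
LORA_KEY_PREFIXES = ("lora_te_", "lora_unet_")


def looks_like_lora(state_dict: dict) -> bool:
    return any(key.startswith(LORA_KEY_PREFIXES) for key in state_dict)


# ...also matches a main checkpoint that happens to ship a few lora_te_* keys,
# so running the LoRA check before the main-model check misclassifies it.
main_model_sd = {
    "model.diffusion_model.input_blocks.0.0.weight": ...,
    "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha": ...,
}
print(looks_like_lora(main_model_sd))  # True, even though this is a main model
```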
psychedelicious
814406d98a feat(mm): siglip model loading supports partial loading
In the previous commit, the LLaVA model was updated to support partial loading.

In this commit, the SigLIP model is updated in the same way.

This model is used for FLUX Redux. It's <4GB and only ever run in isolation, so it won't benefit from partial loading for the vast majority of users. Regardless, I think it is best if we make _all_ models work with partial loading.

PS: I also fixed the initial load dtype issue, described in the prev commit. It's probably a non-issue for this model, but we may as well fix it.
2025-04-18 10:12:03 +10:00
psychedelicious
c054501103 feat(mm): llava model loading supports partial loading; fix OOM crash on initial load
The model manager has two types of model cache entries:
- `CachedModelOnlyFullLoad`: The model may only ever be loaded and unloaded as a single object.
- `CachedModelWithPartialLoad`: The model may be partially loaded and unloaded.

Partial loaded is enabled by overwriting certain torch layer classes, adding the ability to autocast the layer to a device on-the-fly. See `CustomLinear` for an example.

So, to take advantage of partial loading and be cached as a `CachedModelWithPartialLoad`, the model must inherit from `torch.nn.Module`.

The LLaVA classes provided by `transformers` do inherit from `torch.nn.Module`, but we wrap those classes in a separate class called `LlavaOnevisionModel`. The wrapper encapsulates both the LLaVA model and its "processor" - a lightweight class that prepares model inputs like text and images.

While it is more elegant to encapsulate both model and processor classes in a single entity, this prevents the model cache from enabling partial loading for the chunky vLLM model.

Fixing this involved a few changes.
- Update the `LlavaOnevisionModelLoader` class to operate on the vLLM model directly, instead of the `LlavaOnevisionModel` wrapper class.
- Instantiate the processor directly in the node. The processor is lightweight and does its business on the CPU. We don't need to worry about caching in the model manager.
- Remove caching support code from the `LlavaOnevisionModel` wrapper class. It's not needed, because we do not cache this class. The class now only handles running the models provided to it.
- Rename `LlavaOnevisionModel` to `LlavaOnevisionPipeline` to better represent its purpose.

These changes have a bonus effect of fixing an OOM crash when initially loading the models. This was most apparent when loading LLaVA 7B, which is pretty chunky.

The initial load is onto CPU RAM. In the old version of the loaders, we ignored the loader's target dtype for the initial load. Instead, we loaded the model at `transformers`'s "default" dtype of fp32.

LLaVA 7B is fp16 and weighs ~17GB. Loading as fp32 means we need double that amount (~34GB) of CPU RAM. Many users only have 32GB RAM, so this causes a _CPU_ OOM - which is a hard crash of the whole process.

With the updated loaders, the initial load logic now uses the target dtype for the initial load. LLaVA now needs the expected ~17GB RAM for its initial load.

PS: If we didn't make the accompanying partial loading changes, we still could have solved this OOM. We'd just need to pass the initial load dtype to the wrapper class and have it load on that dtype. But we may as well fix both issues.

PPS: There are other models whose model classes are wrappers around a torch module class, and thus cannot be partially loaded. However, these models are typically fairly small and/or are run only on their own, so they don't benefit as much from partial loading. It's the really big models (like LLaVA 7B) that benefit most from the partial loading.
2025-04-18 10:12:03 +10:00
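A small sketch of the initial-load dtype point above; the repo id is just an example, and the real loader goes through the model manager rather than calling `transformers` directly:

```py
import torch
from transformers import LlavaOnevisionForConditionalGeneration

# Passing the target dtype through to the initial (CPU) load keeps an fp16
# checkpoint at roughly its on-disk size in RAM, instead of transformers'
# fp32 default doubling it.
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    "llava-hf/llava-onevision-qwen2-7b-ov-hf",
    torch_dtype=torch.float16,
)
```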
psychedelicious
c1d819c7e5 feat(nodes): add get_absolute_path method to context.models API
Given a model config or path (presumably to a model), returns the absolute path to the model.

Check the next few commits for use-case.
2025-04-18 10:12:03 +10:00
psychedelicious
2a8e91f94d feat(ui): wrap JSON in dataviewer 2025-04-17 22:55:04 +10:00
psychedelicious
64f3e56039 chore: bump version to v5.10.0 2025-04-17 15:08:26 +10:00
Hosted Weblate
819afab230 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-04-17 11:28:02 +10:00
Linos
9fff064c55 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1887 of 1887 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1887 of 1887 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-04-17 11:28:02 +10:00
Riccardo Giovanetti
1aa8d94378 translationBot(ui): update translation (Italian)
Currently translated at 98.0% (1851 of 1887 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-04-17 11:28:02 +10:00
RyoKoba
d78bdde2c3 translationBot(ui): update translation (Japanese)
Currently translated at 56.6% (1069 of 1887 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 50.8% (960 of 1887 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 48.4% (912 of 1882 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI
2025-04-17 11:28:02 +10:00
psychedelicious
7b663b3432 fix(ui): scrolling in builder
I am at a loss as to the cause of this bug. The styles that I needed to change to fix it haven't been changed in a couple of months, but these changes do seem to fix it.

Closes #7910
2025-04-17 11:24:54 +10:00
psychedelicious
9c4159915a feat(ui): add guardrails to prevent entity types being missed in useIsEntityTypeEnabled 2025-04-17 11:21:16 +10:00
psychedelicious
dbb5830027 fix(ui): useIsEntityTypeEnabled should use useMemo not useCallback
Typo/bug introduced in #7770
2025-04-17 11:21:16 +10:00
psychedelicious
4fc4dbb656 fix(ui): ensure query subs are reset in case of error 2025-04-17 11:13:41 +10:00
psychedelicious
d4f6d09cc9 fix(ui): never subscribe to dynamic prompts queries
If the request errors, we would never get to unsubscribe. The request would forever be marked as having a subscriber and never be cleared from memory.
2025-04-17 10:36:09 +10:00
psychedelicious
44e44602d3 feat(ui): remove keepUnusedDataFor for dynamic prompts
This query can have potentially large responses. Keeping them around for 24 hours is essentially a hardcoded memory leak. Use RTKQ's default of 60 seconds.
2025-04-17 10:36:09 +10:00
psychedelicious
36066c5f26 fix(ui): ensure dynamic prompts updates on any change to any dependent state
When users generate on the canvas or upscaling tabs, we parse prompts through dynamic prompts before invoking. Whenever the prompt or other settings change, we run dynamic prompts.

Previously, we used a redux listener to react to changes to dynamic prompts' dependent state, keeping the processed dynamic prompts synced. For example, when the user changed the prompt field, we re-processed the dynamic prompts.

This requires that all redux actions that change the dependent state be added to the listener matcher. It's easy to forget actions, though, which can result in the dynamic prompts state being stale.

For example, when resetting canvas state, we dispatch an action that resets the whole params slice, but this wasn't in the matcher. As a result, when resetting canvas, the dynamic prompts aren't updated. If the user then clicks Invoke (with an empty prompt), the last dynamic prompts state will be used.

For example:
- Generate w/ prompt "frog", get frog
- Click new canvas session
- Generate without any prompt, still get frog

To resolve this, the logic that keeps the dynamic prompts synced is moved from the listener to a hook. The way the logic is triggered is improved - it's now triggered in a useEffect, which is run when the dependent state changes. This way, it doesn't matter _how_ the dependent state changes - the changes will always be "seen", and the dynamic prompts will update.
2025-04-17 10:36:09 +10:00
psychedelicious
361c6eed4b docs: update manual install docs w/ correct pytorch indicies for v5.10.0 and later 2025-04-17 10:32:41 +10:00
psychedelicious
bb154fd40f docs: update dev env docs with correct pytorch pypi index 2025-04-17 10:32:41 +10:00
psychedelicious
cbee6e6faf fix(app): remove accidentally committed tensor cache size
I had set this to zero for testing during the PyTorch 2.6.0 upgrade and neglected to remove it.
2025-04-17 10:12:47 +10:00
1060 changed files with 55644 additions and 26800 deletions

29
.github/CODEOWNERS vendored
View File

@@ -1,32 +1,31 @@
# continuous integration
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku @psychedelicious
# documentation
/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
# nodes
/invokeai/app/ @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
/invokeai/app/ @blessedcoolant @psychedelicious @hipsterusername @jazzhaiku
# installation and configuration
/pyproject.toml @lstein @blessedcoolant @hipsterusername
/docker/ @lstein @blessedcoolant @hipsterusername @ebr
/scripts/ @ebr @lstein @hipsterusername
/installer/ @lstein @ebr @hipsterusername
/invokeai/assets @lstein @ebr @hipsterusername
/invokeai/configs @lstein @hipsterusername
/invokeai/version @lstein @blessedcoolant @hipsterusername
/pyproject.toml @lstein @blessedcoolant @psychedelicious @hipsterusername
/docker/ @lstein @blessedcoolant @psychedelicious @hipsterusername @ebr
/scripts/ @ebr @lstein @psychedelicious @hipsterusername
/installer/ @lstein @ebr @psychedelicious @hipsterusername
/invokeai/assets @lstein @ebr @psychedelicious @hipsterusername
/invokeai/configs @lstein @psychedelicious @hipsterusername
/invokeai/version @lstein @blessedcoolant @psychedelicious @hipsterusername
# web ui
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
# generation, model management, postprocessing
/invokeai/backend @lstein @blessedcoolant @brandonrising @hipsterusername @jazzhaiku
/invokeai/backend @lstein @blessedcoolant @hipsterusername @jazzhaiku @psychedelicious @maryhipp
# front ends
/invokeai/frontend/CLI @lstein @hipsterusername
/invokeai/frontend/install @lstein @ebr @hipsterusername
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/CLI @lstein @psychedelicious @hipsterusername
/invokeai/frontend/install @lstein @ebr @psychedelicious @hipsterusername
/invokeai/frontend/merge @lstein @blessedcoolant @psychedelicious @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @psychedelicious @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp @hipsterusername

View File

@@ -21,6 +21,20 @@ body:
- label: I have searched the existing issues
required: true
- type: dropdown
id: install_method
attributes:
label: Install method
description: How did you install Invoke?
multiple: false
options:
- "Invoke's Launcher"
- 'Stability Matrix'
- 'Pinokio'
- 'Manual'
validations:
required: true
- type: markdown
attributes:
value: __Describe your environment__
@@ -76,8 +90,8 @@ body:
attributes:
label: Version number
description: |
The version of Invoke you have installed. If it is not the latest version, please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: ex. 3.6.1
The version of Invoke you have installed. If it is not the [latest version](https://github.com/invoke-ai/InvokeAI/releases/latest), please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: ex. v6.0.2
validations:
required: true
@@ -85,17 +99,17 @@ body:
id: browser-version
attributes:
label: Browser
description: Your web browser and version.
description: Your web browser and version, if you do not use the Launcher's provided GUI.
placeholder: ex. Firefox 123.0b3
validations:
required: true
required: false
- type: textarea
id: python-deps
attributes:
label: Python dependencies
label: System Information
description: |
If the problem occurred during image generation, click the gear icon at the bottom left corner, click "About", click the copy button and then paste here.
Click the gear icon at the bottom left corner, then click "About". Click the copy button and then paste here.
validations:
required: false

View File

@@ -3,15 +3,15 @@ description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: setup node 18
- name: setup node 20
uses: actions/setup-node@v4
with:
node-version: '18'
node-version: '20'
- name: setup pnpm
uses: pnpm/action-setup@v4
with:
version: 8.15.6
version: 10
run_install: false
- name: get pnpm store directory

View File

@@ -67,6 +67,10 @@ jobs:
version: '0.6.10'
enable-cache: true
- name: check pypi classifiers
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: uv run --no-project scripts/check_classifiers.py ./pyproject.toml
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: uv tool run ruff@0.11.2 check --output-format=github .

4
.gitignore vendored
View File

@@ -180,6 +180,7 @@ cython_debug/
# Scratch folder
.scratch/
.vscode/
.zed/
# source installer files
installer/*zip
@@ -188,3 +189,6 @@ installer/install.sh
installer/update.bat
installer/update.sh
installer/InvokeAI-Installer/
.aider*
.claude/

View File

@@ -5,8 +5,7 @@
FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
RUN corepack enable
RUN corepack use pnpm@10.x && corepack enable
WORKDIR /build
COPY invokeai/frontend/web/ ./

View File

@@ -39,7 +39,7 @@ nodes imported in the `__init__.py` file are loaded. See the README in the nodes
folder for more examples:
```py
from .cool_node import CoolInvocation
from .cool_node import ResizeInvocation
```
## Creating A New Invocation
@@ -69,7 +69,10 @@ The first set of things we need to do when creating a new Invocation are -
So let us do that.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.invocation_api import (
BaseInvocation,
invocation,
)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -103,8 +106,12 @@ create your own custom field types later in this guide. For now, let's go ahead
and use it.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
invocation,
)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -128,8 +135,12 @@ image: ImageField = InputField(description="The input image")
Great. Now let us create our other inputs for `width` and `height`
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
invocation,
)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -163,8 +174,13 @@ that are provided by it by InvokeAI.
Let us create this function first.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@@ -191,8 +207,14 @@ all the necessary info related to image outputs. So let us use that.
We will cover how to create your own output types later in this guide.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.image import ImageOutput
@invocation('resize')
@@ -217,9 +239,15 @@ Perfect. Now that we have our Invocation setup, let us do what we want to do.
So let's do that.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.image import ImageOutput, ResourceOrigin, ImageCategory
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.image import ImageOutput
@invocation("resize")
class ResizeInvocation(BaseInvocation):
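For reference, here is one possible completed version of the node this tutorial builds — a minimal sketch rather than the canonical docs listing. It assumes the `invokeai.invocation_api` exports and the `context.images` helpers available in recent releases; the field defaults and bounds are illustrative only.

```python
from invokeai.invocation_api import (
    BaseInvocation,
    ImageField,
    ImageOutput,
    InputField,
    InvocationContext,
    invocation,
)


@invocation("resize")
class ResizeInvocation(BaseInvocation):
    """Resizes an image to the given dimensions."""

    image: ImageField = InputField(description="The input image")
    width: int = InputField(default=512, ge=64, description="Target width in pixels")
    height: int = InputField(default=512, ge=64, description="Target height in pixels")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load the PIL image referenced by the input field.
        image = context.images.get_pil(self.image.image_name)
        # Do the node's actual work.
        resized = image.resize((self.width, self.height))
        # Save the result and describe it with an ImageOutput.
        image_dto = context.images.save(image=resized)
        return ImageOutput.build(image_dto=image_dto)
```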

View File

@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
With the modifications made, the install command should look something like this:
```sh
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu128 --reinstall
```
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
@@ -50,11 +50,11 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
7. Install the frontend dev toolchain:
7. Install the frontend dev toolchain, paying attention to versions:
- [`nodejs`](https://nodejs.org/) (v20+)
- [`nodejs`](https://nodejs.org/) (tested on LTS, v22)
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
- [`pnpm`](https://pnpm.io/installation) (tested on v10)
8. Do a production build of the frontend:

View File

@@ -71,7 +71,21 @@ The following commands vary depending on the version of Invoke being installed a
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.
=== "Invoke v5 or later"
=== "Invoke v5.12 and later"
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu128`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- **In all other cases, do not use an index.**
=== "Invoke v5.10.0 to v5.11.0"
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu126`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
- **In all other cases, do not use an index.**
=== "Invoke v5.0.0 to v5.9.1"
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
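The current rules can also be written down as a tiny helper, which mirrors only the "Invoke v5.12 and later" bullets above; this is a sketch for illustration, not part of the docs.

```python
def torch_index_url(os_name: str, gpu: str | None) -> str | None:
    """Extra PyPI index for Invoke v5.12+, or None when no index should be used."""
    if gpu == "nvidia" and os_name in ("windows", "linux"):
        return "https://download.pytorch.org/whl/cu128"
    if os_name == "linux" and gpu == "amd":
        return "https://download.pytorch.org/whl/rocm6.2.4"
    if os_name == "linux" and gpu is None:
        return "https://download.pytorch.org/whl/cpu"
    # In all other cases, do not use an index.
    return None
```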

View File

@@ -35,7 +35,7 @@ More detail on system requirements can be found [here](./requirements.md).
## Step 2: Download
Download the most launcher for your operating system:
Download the most recent launcher for your operating system:
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)

View File

@@ -13,6 +13,7 @@ If you'd prefer, you can also just download the whole node folder from the linke
To use a community workflow, download the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
- Community Nodes
+ [Anamorphic Tools](#anamorphic-tools)
+ [Adapters-Linked](#adapters-linked-nodes)
+ [Autostereogram](#autostereogram-nodes)
+ [Average Images](#average-images)
@@ -20,9 +21,12 @@ To use a community workflow, download the `.json` node graph file and load it in
+ [Close Color Mask](#close-color-mask)
+ [Clothing Mask](#clothing-mask)
+ [Contrast Limited Adaptive Histogram Equalization](#contrast-limited-adaptive-histogram-equalization)
+ [Curves](#curves)
+ [Depth Map from Wavefront OBJ](#depth-map-from-wavefront-obj)
+ [Enhance Detail](#enhance-detail)
+ [Film Grain](#film-grain)
+ [Flip Pose](#flip-pose)
+ [Flux Ideal Size](#flux-ideal-size)
+ [Generative Grammar-Based Prompt Nodes](#generative-grammar-based-prompt-nodes)
+ [GPT2RandomPromptMaker](#gpt2randompromptmaker)
+ [Grid to Gif](#grid-to-gif)
@@ -61,6 +65,13 @@ To use a community workflow, download the `.json` node graph file and load it in
- [Help](#help)
--------------------------------
### Anamorphic Tools
**Description:** A set of nodes to perform anamorphic modifications to images, like lens blur, streaks, spherical distortion, and vignetting.
**Node Link:** https://github.com/JPPhoto/anamorphic-tools
--------------------------------
### Adapters Linked Nodes
@@ -132,6 +143,13 @@ Node Link: https://github.com/VeyDlin/clahe-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/clahe-node/master/.readme/node.png" width="500" />
--------------------------------
### Curves
**Description:** Adjust an image's curve based on a user-defined string.
**Node Link:** https://github.com/JPPhoto/curves-node
--------------------------------
### Depth Map from Wavefront OBJ
@@ -162,6 +180,20 @@ To be imported, an .obj must use triangulated meshes, so make sure to enable tha
**Node Link:** https://github.com/JPPhoto/film-grain-node
--------------------------------
### Flip Pose
**Description:** This node will flip an openpose image horizontally, recoloring it to make sure that it isn't facing the wrong direction. Note that it does not work with openpose hands.
**Node Link:** https://github.com/JPPhoto/flip-pose-node
--------------------------------
### Flux Ideal Size
**Description:** This node returns an ideal size to use for the first stage of a Flux image generation pipeline. Generating at the right size helps limit duplication and odd subject placement.
**Node Link:** https://github.com/JPPhoto/flux-ideal-size
--------------------------------
### Generative Grammar-Based Prompt Nodes

View File

@@ -23,6 +23,10 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import (
SqliteModelRelationshipRecordStorage,
)
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
from invokeai.app.services.names.names_default import SimpleNameService
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
@@ -113,7 +117,6 @@ class ApiDependencies:
safe_globals=[torch.Tensor],
ephemeral=True,
),
max_cache_size=0,
)
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](
@@ -137,6 +140,8 @@ class ApiDependencies:
download_queue=download_queue_service,
events=events,
)
model_relationships = ModelRelationshipsService()
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
@@ -162,6 +167,8 @@ class ApiDependencies:
logger=logger,
model_images=model_images_service,
model_manager=model_manager,
model_relationships=model_relationships,
model_relationship_records=model_relationship_records,
download_queue=download_queue_service,
names=names,
performance_statistics=performance_statistics,

View File

@@ -1,8 +1,7 @@
import typing
from enum import Enum
from importlib.metadata import PackageNotFoundError, version
from importlib.metadata import distributions
from pathlib import Path
from platform import python_version
from typing import Optional
import torch
@@ -44,24 +43,6 @@ class AppVersion(BaseModel):
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
class AppDependencyVersions(BaseModel):
"""App depencency Versions Response"""
accelerate: str = Field(description="accelerate version")
compel: str = Field(description="compel version")
cuda: Optional[str] = Field(description="CUDA version")
diffusers: str = Field(description="diffusers version")
numpy: str = Field(description="Numpy version")
opencv: str = Field(description="OpenCV version")
onnx: str = Field(description="ONNX version")
pillow: str = Field(description="Pillow (PIL) version")
python: str = Field(description="Python version")
torch: str = Field(description="PyTorch version")
torchvision: str = Field(description="PyTorch Vision version")
transformers: str = Field(description="transformers version")
xformers: Optional[str] = Field(description="xformers version")
class AppConfig(BaseModel):
"""App Config Response"""
@@ -76,27 +57,19 @@ async def get_version() -> AppVersion:
return AppVersion(version=__version__)
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
async def get_app_deps() -> AppDependencyVersions:
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=dict[str, str])
async def get_app_deps() -> dict[str, str]:
deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
try:
xformers = version("xformers")
except PackageNotFoundError:
xformers = None
return AppDependencyVersions(
accelerate=version("accelerate"),
compel=version("compel"),
cuda=torch.version.cuda,
diffusers=version("diffusers"),
numpy=version("numpy"),
opencv=version("opencv-python"),
onnx=version("onnx"),
pillow=version("pillow"),
python=python_version(),
torch=torch.version.__version__,
torchvision=version("torchvision"),
transformers=version("transformers"),
xformers=xformers,
)
cuda = torch.version.cuda or "N/A"
except Exception:
cuda = "N/A"
deps["CUDA"] = cuda
sorted_deps = dict(sorted(deps.items(), key=lambda item: item[0].lower()))
return sorted_deps
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
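The reworked handler replaces the fixed `AppDependencyVersions` schema with a flat mapping of every installed distribution plus a `CUDA` entry, sorted case-insensitively. A standalone approximation of that logic (not the route itself) looks like this:

```python
from importlib.metadata import distributions


def collect_app_deps(cuda_version: str | None = None) -> dict[str, str]:
    # Every installed distribution, keyed by name, plus a CUDA entry.
    deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
    deps["CUDA"] = cuda_version or "N/A"
    # Sort case-insensitively, as the endpoint does.
    return dict(sorted(deps.items(), key=lambda item: item[0].lower()))
```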

View File

@@ -1,21 +1,12 @@
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
class AddImagesToBoardResult(BaseModel):
board_id: str = Field(description="The id of the board the images were added to")
added_image_names: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(BaseModel):
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
@board_images_router.post(
"/",
operation_id="add_image_to_board",
@@ -23,17 +14,26 @@ class RemoveImagesFromBoardResult(BaseModel):
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
response_model=AddImagesToBoardResult,
)
async def add_image_to_board(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
):
) -> AddImagesToBoardResult:
"""Creates a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id, image_name=image_name
added_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
added_images.add(image_name)
affected_boards.add(board_id)
affected_boards.add(old_board_id)
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to add image to board")
@@ -45,14 +45,25 @@ async def add_image_to_board(
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
response_model=RemoveImagesFromBoardResult,
)
async def remove_image_from_board(
image_name: str = Body(description="The name of the image to remove", embed=True),
):
) -> RemoveImagesFromBoardResult:
"""Removes an image from its board, if it had one"""
try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
return result
removed_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@@ -72,16 +83,25 @@ async def add_images_to_board(
) -> AddImagesToBoardResult:
"""Adds a list of images to a board"""
try:
added_image_names: list[str] = []
added_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id, image_name=image_name
board_id=board_id,
image_name=image_name,
)
added_image_names.append(image_name)
added_images.add(image_name)
affected_boards.add(board_id)
affected_boards.add(old_board_id)
except Exception:
pass
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@@ -100,13 +120,20 @@ async def remove_images_from_board(
) -> RemoveImagesFromBoardResult:
"""Removes a list of images from their board, if they had one"""
try:
removed_image_names: list[str] = []
removed_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_image_names.append(image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
except Exception:
pass
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove images from board")
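Both board endpoints now report which boards were touched so the UI can invalidate the right caches. A rough client-side sketch of the new response shape, assuming a local server on the default port (9090) and placeholder ids:

```python
import requests

resp = requests.post(
    "http://localhost:9090/api/v1/board_images/",
    json={"board_id": "my-board-id", "image_name": "example.png"},  # placeholder values
    timeout=10,
)
resp.raise_for_status()
result = resp.json()
print("added:", result["added_images"])
print("boards to refresh:", result["affected_boards"])
```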

View File

@@ -146,7 +146,7 @@ async def list_boards(
response_model=list[str],
)
async def list_all_board_image_names(
board_id: str = Path(description="The id of the board"),
board_id: str = Path(description="The id of the board or 'none' for uncategorized images"),
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
) -> list[str]:

View File

@@ -1,24 +1,34 @@
import io
import json
import traceback
from typing import Optional
from typing import ClassVar, Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_image
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.images.images_common import (
DeleteImagesResult,
ImageDTO,
ImageUrlsDTO,
StarredImagesResult,
UnstarredImagesResult,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.controlnet_utils import heuristic_resize_fast
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@@ -27,6 +37,19 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
IMAGE_MAX_AGE = 31536000
class ResizeToDimensions(BaseModel):
width: int = Field(..., gt=0)
height: int = Field(..., gt=0)
MAX_SIZE: ClassVar[int] = 4096 * 4096
@model_validator(mode="after")
def validate_total_output_size(self):
if self.width * self.height > self.MAX_SIZE:
raise ValueError(f"Max total output size for resizing is {self.MAX_SIZE} pixels")
return self
@images_router.post(
"/upload",
operation_id="upload_image",
@@ -46,6 +69,11 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
resize_to: Optional[str] = Body(
default=None,
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
examples=['"[1024,1024]"'],
),
metadata: Optional[str] = Body(
default=None,
description="The metadata to associate with the image, must be a stringified JSON dict",
@@ -59,13 +87,33 @@ async def upload_image(
contents = await file.read()
try:
pil_image = Image.open(io.BytesIO(contents))
if crop_visible:
bbox = pil_image.getbbox()
pil_image = pil_image.crop(bbox)
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
if crop_visible:
try:
bbox = pil_image.getbbox()
pil_image = pil_image.crop(bbox)
except Exception:
raise HTTPException(status_code=500, detail="Failed to crop image")
if resize_to:
try:
dims = json.loads(resize_to)
resize_dims = ResizeToDimensions(**dims)
except Exception:
raise HTTPException(status_code=400, detail="Invalid resize_to format or size")
try:
# heuristic_resize_fast expects an RGB or RGBA image
pil_rgba = pil_image.convert("RGBA")
np_image = pil_to_np(pil_rgba)
np_image = heuristic_resize_fast(np_image, (resize_dims.width, resize_dims.height))
pil_image = np_to_pil(np_image)
except Exception:
raise HTTPException(status_code=500, detail="Failed to resize image")
extracted_metadata = extract_metadata_from_image(
pil_image=pil_image,
invokeai_metadata_override=metadata,
@@ -112,18 +160,30 @@ async def create_image_upload_entry(
raise HTTPException(status_code=501, detail="Not implemented")
@images_router.delete("/i/{image_name}", operation_id="delete_image")
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),
) -> None:
) -> DeleteImagesResult:
"""Deletes an image"""
deleted_images: set[str] = set()
affected_boards: set[str] = set()
try:
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
except Exception:
# TODO: Does this need any exception handling at all?
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
@images_router.delete("/intermediates", operation_id="clear_intermediates")
async def clear_intermediates() -> int:
@@ -335,23 +395,52 @@ async def list_image_dtos(
return image_dtos
class DeleteImagesFromListResult(BaseModel):
deleted_images: list[str]
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
async def delete_images_from_list(
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
) -> DeleteImagesFromListResult:
) -> DeleteImagesResult:
try:
deleted_images: list[str] = []
deleted_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
except Exception:
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
async def delete_uncategorized_images() -> DeleteImagesResult:
"""Deletes all images that are uncategorized"""
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id="none", categories=None, is_intermediate=None
)
try:
deleted_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.append(image_name)
deleted_images.add(image_name)
affected_boards.add("none")
except Exception:
pass
return DeleteImagesFromListResult(deleted_images=deleted_images)
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@@ -360,36 +449,50 @@ class ImagesUpdatedFromListResult(BaseModel):
updated_image_names: list[str] = Field(description="The image names that were updated")
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
async def star_images_in_list(
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> ImagesUpdatedFromListResult:
) -> StarredImagesResult:
try:
updated_image_names: list[str] = []
starred_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
updated_image_names.append(image_name)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=True)
)
starred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
except Exception:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
return StarredImagesResult(
starred_images=list(starred_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to star images")
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
async def unstar_images_in_list(
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> ImagesUpdatedFromListResult:
) -> UnstarredImagesResult:
try:
updated_image_names: list[str] = []
unstarred_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
updated_image_names.append(image_name)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=False)
)
unstarred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
except Exception:
pass
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
return UnstarredImagesResult(
unstarred_images=list(unstarred_images),
affected_boards=list(affected_boards),
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")
@@ -460,3 +563,61 @@ async def get_bulk_download_item(
return response
except Exception:
raise HTTPException(status_code=404)
@images_router.get("/names", operation_id="get_image_names")
async def get_image_names(
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates"""
try:
result = ApiDependencies.invoker.services.images.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image names")
@images_router.post(
"/images_by_names",
operation_id="get_images_by_names",
responses={200: {"model": list[ImageDTO]}},
)
async def get_images_by_names(
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
) -> list[ImageDTO]:
"""Gets image DTOs for the specified image names. Maintains order of input names."""
try:
image_service = ApiDependencies.invoker.services.images
# Fetch DTOs preserving the order of requested names
image_dtos: list[ImageDTO] = []
for name in image_names:
try:
dto = image_service.get_dto(name)
image_dtos.append(dto)
except Exception:
# Skip missing images - they may have been deleted between name fetch and DTO fetch
continue
return image_dtos
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image DTOs")
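The two new endpoints are intended to be used together: fetch the full ordered list of names first, then hydrate DTOs for just the slice being rendered. A rough sketch, assuming a local server on port 9090 and that the `ImageNamesResult` payload exposes an `image_names` list (not shown in this diff):

```python
import requests

BASE = "http://localhost:9090/api/v1/images"

names_resp = requests.get(f"{BASE}/names", params={"board_id": "none", "starred_first": True}, timeout=10)
names_resp.raise_for_status()
image_names = names_resp.json()["image_names"]  # assumed field name on ImageNamesResult

# Hydrate DTOs for the first page only; the order of the requested names is preserved.
dtos_resp = requests.post(f"{BASE}/images_by_names", json={"image_names": image_names[:50]}, timeout=10)
dtos_resp.raise_for_status()
print(f"fetched {len(dtos_resp.json())} of {len(image_names)} image DTOs")
```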

View File

@@ -41,6 +41,7 @@ from invokeai.backend.model_manager.starter_models import (
STARTER_BUNDLES,
STARTER_MODELS,
StarterModel,
StarterModelBundle,
StarterModelWithoutDependencies,
)
@@ -291,7 +292,7 @@ async def get_hugging_face_models(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
) -> AnyModelConfig:
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
@@ -449,7 +450,7 @@ async def install_model(
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
example={"name": "string", "description": "string"},
examples=[{"name": "string", "description": "string"}],
),
) -> ModelInstallJob:
"""Install a model using a string identifier.
@@ -799,7 +800,7 @@ async def convert_model(
class StarterModelResponse(BaseModel):
starter_models: list[StarterModel]
starter_bundles: dict[str, list[StarterModel]]
starter_bundles: dict[str, StarterModelBundle]
def get_is_installed(
@@ -833,7 +834,7 @@ async def get_starter_models() -> StarterModelResponse:
model.dependencies = missing_deps
for bundle in starter_bundles.values():
for model in bundle:
for model in bundle.models:
model.is_installed = get_is_installed(model, installed_models)
# Remove already-installed dependencies
missing_deps: list[StarterModelWithoutDependencies] = []
@@ -893,6 +894,12 @@ class HFTokenHelper:
huggingface_hub.login(token=token, add_to_git_credential=False)
return cls.get_status()
@classmethod
def reset_token(cls) -> HFTokenStatus:
with SuppressOutput(), contextlib.suppress(Exception):
huggingface_hub.logout()
return cls.get_status()
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
async def get_hf_login_status() -> HFTokenStatus:
@@ -915,3 +922,8 @@ async def do_hf_login(
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
return token_status
@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
async def reset_hf_token() -> HFTokenStatus:
return HFTokenHelper.reset_token()

View File

@@ -0,0 +1,215 @@
"""FastAPI route for model relationship records."""
from typing import List
from fastapi import APIRouter, Body, HTTPException, Path, status
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
model_relationships_router = APIRouter(prefix="/v1/model_relationships", tags=["model_relationships"])
# === Schemas ===
class ModelRelationshipCreateRequest(BaseModel):
model_key_1: str = Field(
...,
description="The key of the first model in the relationship",
examples=[
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
"ac32b914-10ab-496e-a24a-3068724b9c35",
"d944abfd-c7c3-42e2-a4ff-da640b29b8b4",
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
"12345678-90ab-cdef-1234-567890abcdef",
"fedcba98-7654-3210-fedc-ba9876543210",
],
)
model_key_2: str = Field(
...,
description="The key of the second model in the relationship",
examples=[
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
"f0c3da4e-d9ff-42b5-a45c-23be75c887c9",
"38170dd8-f1e5-431e-866c-2c81f1277fcc",
"c57fea2d-7646-424c-b9ad-c0ba60fc68be",
"10f7807b-ab54-46a9-ab03-600e88c630a1",
"f6c1d267-cf87-4ee0-bee0-37e791eacab7",
],
)
class ModelRelationshipBatchRequest(BaseModel):
model_keys: List[str] = Field(
...,
description="List of model keys to fetch related models for",
examples=[
[
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
"ac32b914-10ab-496e-a24a-3068724b9c35",
],
[
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
"12345678-90ab-cdef-1234-567890abcdef",
"fedcba98-7654-3210-fedc-ba9876543210",
],
[
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
],
],
)
# === Routes ===
@model_relationships_router.get(
"/i/{model_key}",
operation_id="get_related_models",
response_model=list[str],
responses={
200: {
"description": "A list of related model keys was retrieved successfully",
"content": {
"application/json": {
"example": [
"15e9eb28-8cfe-47c9-b610-37907a79fc3c",
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
]
}
},
},
404: {"description": "The specified model could not be found"},
422: {"description": "Validation error"},
},
)
async def get_related_models(
model_key: str = Path(..., description="The key of the model to get relationships for"),
) -> list[str]:
"""
Get a list of model keys related to a given model.
"""
try:
return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@model_relationships_router.post(
"/",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "The relationship was successfully created"},
400: {"description": "Invalid model keys or self-referential relationship"},
409: {"description": "The relationship already exists"},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Add Model Relationship",
description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.",
)
async def add_model_relationship(
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate"),
) -> None:
"""
Add a relationship between two models.
Relationships are bidirectional and will be accessible from both models.
- Raises 400 if keys are invalid or identical.
- Raises 409 if the relationship already exists.
"""
try:
if req.model_key_1 == req.model_key_2:
raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")
ApiDependencies.invoker.services.model_relationships.add_model_relationship(
req.model_key_1,
req.model_key_2,
)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@model_relationships_router.delete(
"/",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "The relationship was successfully removed"},
400: {"description": "Invalid model keys or self-referential relationship"},
404: {"description": "The relationship does not exist"},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Remove Model Relationship",
description="Removes a **bidirectional** relationship between two models. The relationship must already exist.",
)
async def remove_model_relationship(
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect"),
) -> None:
"""
Removes a bidirectional relationship between two model keys.
- Raises 400 if attempting to unlink a model from itself.
- Raises 404 if the relationship was not found.
"""
try:
if req.model_key_1 == req.model_key_2:
raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.")
ApiDependencies.invoker.services.model_relationships.remove_model_relationship(
req.model_key_1,
req.model_key_2,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@model_relationships_router.post(
"/batch",
operation_id="get_related_models_batch",
response_model=List[str],
responses={
200: {
"description": "Related model keys retrieved successfully",
"content": {
"application/json": {
"example": [
"ca562b14-995e-4a42-90c1-9528f1a5921d",
"cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f",
"18ca7649-6a9e-47d5-bc17-41ab1e8cec81",
"7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1",
"c382eaa3-0e28-4ab0-9446-408667699aeb",
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
]
}
},
},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Get Related Model Keys (Batch)",
description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering.",
)
async def get_related_models_batch(
req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections"),
) -> list[str]:
"""
Accepts multiple model keys and returns a flat list of all unique related keys.
Useful when working with multiple selections in the UI or cross-model comparisons.
"""
try:
all_related: set[str] = set()
for key in req.model_keys:
related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key)
all_related.update(related)
return list(all_related)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
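A minimal sketch of calling the new batch endpoint from a client, assuming a local server on port 9090; the model key is a placeholder:

```python
import requests

resp = requests.post(
    "http://localhost:9090/api/v1/model_relationships/batch",
    json={"model_keys": ["aa3b247f-90c9-4416-bfcd-aeaa57a5339e"]},  # placeholder key
    timeout=10,
)
resp.raise_for_status()
related_keys: list[str] = resp.json()  # flat, de-duplicated list of related model keys
print(related_keys)
```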

View File

@@ -1,6 +1,6 @@
from typing import Optional
from fastapi import Body, Path, Query
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
@@ -14,13 +14,15 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByBatchIDsResult,
CancelByDestinationResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
FieldIdentifier,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -58,17 +60,19 @@ async def enqueue_batch(
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
try:
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
@session_queue_router.get(
"/{queue_id}/list",
operation_id="list_queue_items",
responses={
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
200: {"model": CursorPaginatedResults[SessionQueueItem]},
},
)
async def list_queue_items(
@@ -77,12 +81,42 @@ async def list_queue_items(
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
priority: int = Query(default=0, description="The pagination cursor priority"),
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets all queue items (without graphs)"""
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
)
try:
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
@session_queue_router.get(
"/{queue_id}/list_all",
operation_id="list_all_queue_items",
responses={
200: {"model": list[SessionQueueItem]},
},
)
async def list_all_queue_items(
queue_id: str = Path(description="The queue id to perform this operation on"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
@session_queue_router.put(
@@ -94,7 +128,10 @@ async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Resumes session processor"""
return ApiDependencies.invoker.services.session_processor.resume()
try:
return ApiDependencies.invoker.services.session_processor.resume()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while resuming queue: {e}")
@session_queue_router.put(
@@ -106,7 +143,10 @@ async def Pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Pauses session processor"""
return ApiDependencies.invoker.services.session_processor.pause()
try:
return ApiDependencies.invoker.services.session_processor.pause()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pausing queue: {e}")
@session_queue_router.put(
@@ -118,7 +158,25 @@ async def cancel_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> CancelAllExceptCurrentResult:
"""Immediately cancels all queue items except in-processing items"""
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
try:
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
@session_queue_router.put(
"/{queue_id}/delete_all_except_current",
operation_id="delete_all_except_current",
responses={200: {"model": DeleteAllExceptCurrentResult}},
)
async def delete_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> DeleteAllExceptCurrentResult:
"""Immediately deletes all queue items except in-processing items"""
try:
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
@session_queue_router.put(
@@ -131,7 +189,12 @@ async def cancel_by_batch_ids(
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
queue_id=queue_id, batch_ids=batch_ids
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
@session_queue_router.put(
@@ -144,9 +207,12 @@ async def cancel_by_destination(
destination: str = Query(description="The destination to cancel all queue items for"),
) -> CancelByDestinationResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
@session_queue_router.put(
@@ -159,7 +225,10 @@ async def retry_items_by_id(
item_ids: list[int] = Body(description="The queue item ids to retry"),
) -> RetryItemsResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
try:
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
@session_queue_router.put(
@@ -173,11 +242,14 @@ async def clear(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
@session_queue_router.put(
@@ -191,7 +263,10 @@ async def prune(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
try:
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
@session_queue_router.get(
@@ -205,7 +280,10 @@ async def get_current_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
try:
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
@session_queue_router.get(
@@ -219,7 +297,10 @@ async def get_next_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
try:
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
@session_queue_router.get(
@@ -233,9 +314,12 @@ async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting queue status: {e}")
@session_queue_router.get(
@@ -250,7 +334,10 @@ async def get_batch_status(
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of the session queue"""
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
try:
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
@session_queue_router.get(
@@ -266,7 +353,27 @@ async def get_queue_item(
item_id: int = Path(description="The queue item to get"),
) -> SessionQueueItem:
"""Gets a queue item"""
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
try:
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching queue item: {e}")
@session_queue_router.delete(
"/{queue_id}/i/{item_id}",
operation_id="delete_queue_item",
)
async def delete_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to delete"),
) -> None:
"""Deletes a queue item"""
try:
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
@session_queue_router.put(
@@ -281,8 +388,12 @@ async def cancel_queue_item(
item_id: int = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Deletes a queue item"""
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
try:
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
@session_queue_router.get(
@@ -295,6 +406,27 @@ async def counts_by_destination(
destination: str = Query(description="The destination to query"),
) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination"""
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
try:
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
@session_queue_router.delete(
"/{queue_id}/d/{destination}",
operation_id="delete_by_destination",
responses={200: {"model": DeleteByDestinationResult}},
)
async def delete_by_destination(
queue_id: str = Path(description="The queue id to query"),
destination: str = Path(description="The destination to query"),
) -> DeleteByDestinationResult:
"""Deletes all items with the given destination"""
try:
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
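A quick sketch of the new `list_all` route from a client's point of view, assuming the queue router's usual `/api/v1/queue` prefix, the `default` queue id, and a local server on port 9090 (none of which appear in this hunk):

```python
import requests

resp = requests.get(
    "http://localhost:9090/api/v1/queue/default/list_all",
    params={"destination": "canvas"},  # optional filter; placeholder value
    timeout=10,
)
resp.raise_for_status()
items = resp.json()
print(f"{len(items)} queue items")
```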

View File

@@ -22,6 +22,7 @@ from invokeai.app.api.routers import (
download_queue,
images,
model_manager,
model_relationships,
session_queue,
style_presets,
utilities,
@@ -125,6 +126,7 @@ app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(model_relationships.model_relationships_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
@@ -156,7 +158,7 @@ web_root_path = Path(list(web_dir.__path__)[0])
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here

View File

@@ -5,6 +5,8 @@ from __future__ import annotations
import inspect
import re
import sys
import types
import typing
import warnings
from abc import ABC, abstractmethod
from enum import Enum
@@ -20,12 +22,14 @@ from typing import (
Literal,
Optional,
Type,
TypedDict,
TypeVar,
Union,
cast,
)
import semver
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
@@ -72,13 +76,24 @@ class Classification(str, Enum, metaclass=MetaEnum):
Special = "special"
class Bottleneck(str, Enum, metaclass=MetaEnum):
"""
The bottleneck of an invocation.
- `Network`: The invocation's execution is network-bound.
- `GPU`: The invocation's execution is GPU-bound.
"""
Network = "network"
GPU = "gpu"
class UIConfigBase(BaseModel):
"""
Provides additional node configuration to the UI.
This is used internally by the @invocation decorator logic. Do not use this directly.
"""
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
tags: Optional[list[str]] = Field(default=None, description="The node's tags")
title: Optional[str] = Field(default=None, description="The node's display name")
category: Optional[str] = Field(default=None, description="The node's category")
version: str = Field(
@@ -93,6 +108,11 @@ class UIConfigBase(BaseModel):
)
class OriginalModelField(TypedDict):
annotation: Any
field_info: FieldInfo
class BaseInvocationOutput(BaseModel):
"""
Base class for all invocation outputs.
@@ -100,6 +120,12 @@ class BaseInvocationOutput(BaseModel):
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""
output_meta: Optional[dict[str, JsonValue]] = Field(
default=None,
description="Optional dictionary of metadata for the invocation output, unrelated to the invocation's actual output value. This is not exposed as an output field.",
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
@@ -115,6 +141,9 @@ class BaseInvocationOutput(BaseModel):
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
return cls.model_fields["type"].default
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
"""The original model fields, before any modifications were made by the @invocation_output decorator."""
model_config = ConfigDict(
protected_namespaces=(),
validate_assignment=True,
@@ -148,7 +177,7 @@ class BaseInvocation(ABC, BaseModel):
return cls.model_fields["type"].default
@classmethod
def get_output_annotation(cls) -> BaseInvocationOutput:
def get_output_annotation(cls) -> Type[BaseInvocationOutput]:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation
@@ -180,7 +209,7 @@ class BaseInvocation(ABC, BaseModel):
Internal invoke method, calls `invoke()` after some prep.
Handles optional fields that are required to call `invoke()` and invocation cache.
"""
for field_name, field in self.model_fields.items():
for field_name, field in type(self).model_fields.items():
if not field.json_schema_extra or callable(field.json_schema_extra):
# something has gone terribly awry, we should always have this and it should be a dict
continue
@@ -195,9 +224,9 @@ class BaseInvocation(ABC, BaseModel):
setattr(self, field_name, orig_default)
if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
if input_ == Input.Connection:
raise RequiredConnectionException(self.model_fields["type"].default, field_name)
raise RequiredConnectionException(type(self).model_fields["type"].default, field_name)
elif input_ == Input.Any:
raise MissingInputException(self.model_fields["type"].default, field_name)
raise MissingInputException(type(self).model_fields["type"].default, field_name)
# skip node cache codepath if it's disabled
if services.configuration.node_cache_size == 0:
@@ -235,6 +264,8 @@ class BaseInvocation(ABC, BaseModel):
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
bottleneck: ClassVar[Bottleneck]
UIConfig: ClassVar[UIConfigBase]
model_config = ConfigDict(
@@ -245,6 +276,9 @@ class BaseInvocation(ABC, BaseModel):
coerce_numbers_to_str=True,
)
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
"""The original model fields, before any modifications were made by the @invocation decorator."""
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
@@ -256,6 +290,26 @@ class InvocationRegistry:
@classmethod
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
invocation_type = invocation.get_type()
node_pack = invocation.UIConfig.node_pack
# Log a warning when an existing invocation is being clobbered by the one we are registering
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
if clobbered_invocation is not None:
# This should always be true - we just checked if the invocation type was in the set
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
if clobbered_node_pack == "invokeai":
# The invocation being clobbered is a core invocation
logger.warning(f'Overriding core node "{invocation_type}" with node from "{node_pack}"')
else:
# The invocation being clobbered is a custom invocation
logger.warning(
f'Overriding node "{invocation_type}" from "{node_pack}" with node from "{clobbered_node_pack}"'
)
cls._invocation_classes.remove(clobbered_invocation)
cls._invocation_classes.add(invocation)
cls.invalidate_invocation_typeadapter()
@@ -314,6 +368,15 @@ class InvocationRegistry:
@classmethod
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
output_type = output.get_type()
# Log a warning when an existing invocation is being clobbered by the one we are registering
clobbered_output = InvocationRegistry.get_output_for_type(output_type)
if clobbered_output is not None:
# TODO(psyche): We do not record the node pack of the output, so we cannot log it here
logger.warning(f'Overriding invocation output "{output_type}"')
cls._output_classes.remove(clobbered_output)
cls._output_classes.add(output)
cls.invalidate_output_typeadapter()
@@ -322,6 +385,11 @@ class InvocationRegistry:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
def get_outputs_map(cls) -> dict[str, type[BaseInvocationOutput]]:
"""Gets a map of all output types to their output classes."""
return {i.get_type(): i for i in cls.get_output_classes()}
@classmethod
@lru_cache(maxsize=1)
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
@@ -347,6 +415,11 @@ class InvocationRegistry:
"""Gets all invocation output types."""
return (i.get_type() for i in cls.get_output_classes())
@classmethod
def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | None:
"""Gets the output class for a given output type."""
return cls.get_outputs_map().get(output_type)
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
@@ -354,11 +427,12 @@ RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"use_cache",
"type",
"workflow",
"bottleneck",
}
RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
RESERVED_OUTPUT_FIELD_NAMES = {"type", "output_meta"}
class _Model(BaseModel):
@@ -425,11 +499,53 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warn(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
field.json_schema_extra.pop("ui_type")
return None
class NoDefaultSentinel:
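"""Sentinel used by validate_field_default to distinguish "no original default recorded" from a default of None."""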
pass
def validate_field_default(
cls_name: str, field_name: str, invocation_type: str, annotation: Any, field_info: FieldInfo
) -> None:
"""Validates the default value of a field against its pydantic field definition."""
assert isinstance(field_info.json_schema_extra, dict), "json_schema_extra is not a dict"
# By the time we are doing this, we've already done some pydantic magic by overriding the original default value.
# We store the original default value in the json_schema_extra dict, so we can validate it here.
orig_default = field_info.json_schema_extra.get("orig_default", NoDefaultSentinel)
if orig_default is NoDefaultSentinel:
return
# To validate the default value, we can create a temporary pydantic model with the field we are validating as its
# only field. Then validate the default value against this temporary model.
TempDefaultValidator = cast(BaseModel, create_model(cls_name, **{field_name: (annotation, field_info)}))
try:
TempDefaultValidator.model_validate({field_name: orig_default})
except Exception as e:
raise InvalidFieldError(
f'Default value for field "{field_name}" on invocation "{invocation_type}" is invalid, {e}'
) from e
def is_optional(annotation: Any) -> bool:
"""
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
"""
origin = typing.get_origin(annotation)
# PEP 604 unions (int|None) have origin types.UnionType
is_union = origin is typing.Union or origin is types.UnionType
if not is_union:
return False
return any(arg is type(None) for arg in typing.get_args(annotation))
def invocation(
invocation_type: str,
title: Optional[str] = None,
@@ -438,6 +554,7 @@ def invocation(
version: Optional[str] = None,
use_cache: Optional[bool] = True,
classification: Classification = Classification.Stable,
bottleneck: Bottleneck = Bottleneck.GPU,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Registers an invocation.
@@ -449,6 +566,7 @@ def invocation(
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
:param Bottleneck bottleneck: The bottleneck of the invocation. Defaults to Bottleneck.GPU. Use Network if the invocation is network-bound.
"""
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@@ -460,27 +578,28 @@ def invocation(
# The node pack is the module name - will be "invokeai" for built-in nodes
node_pack = cls.__module__.split(".")[0]
# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in InvocationRegistry.get_invocation_types():
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
if clobbered_node_pack == "invokeai":
# The node being clobbered is a core node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
)
else:
# The node being clobbered is a custom node
raise ValueError(
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
)
validate_fields(cls.model_fields, invocation_type)
fields: dict[str, tuple[Any, FieldInfo]] = {}
original_model_fields: dict[str, OriginalModelField] = {}
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
if field_info.default is None and not is_optional(annotation):
annotation = annotation | None
fields[field_name] = (annotation, field_info)
# Add OpenAPI schema extras
uiconfig: dict[str, Any] = {}
uiconfig["title"] = title
@@ -496,7 +615,7 @@ def invocation(
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
uiconfig["version"] = version
else:
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
logger.warning(f'No version specified for node "{invocation_type}", using "1.0.0"')
uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)
@@ -504,6 +623,8 @@ def invocation(
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache
cls.bottleneck = bottleneck
# Add the invocation type to the model.
# You'd be tempted to just add the type field and rebuild the model, like this:
@@ -513,11 +634,27 @@ def invocation(
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = Field(
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
invocation_type_annotation = Literal[invocation_type]
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
# This cast makes the type annotation match the class's true type.
invocation_type_field_info = cast(
FieldInfo,
Field(title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
)
fields["type"] = (invocation_type_annotation, invocation_type_field_info)
# Invocation outputs must be registered using the @invocation_output decorator, but it is possible that the
# output is registered _after_ this invocation is registered. It depends on module import ordering.
#
# We can only confirm the output for an invocation is registered after all modules are imported. There's
# only really one good time to do that - during application startup, in `run_app.py`, after loading all
# custom nodes.
#
# We can still do some basic validation here - ensure the invoke method is defined and returns an instance
# of BaseInvocationOutput.
# Validate the `invoke()` method is implemented
if "invoke" in cls.__abstractmethods__:
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
@@ -539,17 +676,13 @@ def invocation(
)
docstring = cls.__doc__
cls = create_model(
cls.__qualname__,
__base__=cls,
__module__=cls.__module__,
type=(invocation_type_annotation, invocation_type_field),
)
cls.__doc__ = docstring
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) # type: ignore
new_class.__doc__ = docstring
new_class._original_model_fields = original_model_fields
InvocationRegistry.register_invocation(cls)
InvocationRegistry.register_invocation(new_class)
return cls
return new_class
return wrapper
@@ -572,29 +705,41 @@ def invocation_output(
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
if output_type in InvocationRegistry.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')
validate_fields(cls.model_fields, output_type)
# Add the output type to the model.
fields: dict[str, tuple[Any, FieldInfo]] = {}
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
if field_info.default is not PydanticUndefined and is_optional(annotation):
annotation = annotation | None
fields[field_name] = (annotation, field_info)
# Add the output type to the model.
output_type_annotation = Literal[output_type]
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
# This cast makes the type annotation match the class's true type.
output_type_field_info = cast(
FieldInfo,
Field(title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
)
fields["type"] = (output_type_annotation, output_type_field_info)
docstring = cls.__doc__
cls = create_model(
cls.__qualname__,
__base__=cls,
__module__=cls.__module__,
type=(output_type_annotation, output_type_field),
)
cls.__doc__ = docstring
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
new_class.__doc__ = docstring
InvocationRegistry.register_output(cls)
InvocationRegistry.register_output(new_class)
return cls
return new_class
return wrapper

View File

@@ -64,7 +64,6 @@ class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(
default=[],
min_length=1,
description="The images to batch over",
)
@@ -120,7 +119,6 @@ class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
strings: list[str] = InputField(
default=[],
min_length=1,
description="The strings to batch over",
)
@@ -176,7 +174,6 @@ class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
integers: list[int] = InputField(
default=[],
min_length=1,
description="The integers to batch over",
)
@@ -230,7 +227,6 @@ class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
floats: list[float] = InputField(
default=[],
min_length=1,
description="The floats to batch over",
)

View File

@@ -0,0 +1,158 @@
import cv2
import numpy as np
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
OutputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_aux.open_pose import Body, Face, Hand, OpenposeDetector
from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.invocation_api import Classification, ImageOutput
DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf"
HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/"
class BriaControlNetField(BaseModel):
image: ImageField = Field(description="The control image")
model: ModelIdentifierField = Field(description="The ControlNet model to use")
mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
conditioning_scale: float = Field(description="The weight given to the ControlNet")
@invocation_output("bria_controlnet_output")
class BriaControlNetOutput(BaseInvocationOutput):
"""Bria ControlNet info"""
control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
preprocessed_images: ImageField = OutputField(description="The preprocessed control image")
@invocation(
"bria_controlnet",
title="ControlNet - Bria",
tags=["controlnet", "bria"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Collect Bria ControlNet info to pass to denoiser node."""
control_image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
)
control_mode: BRIA_CONTROL_MODES = InputField(default="depth", description="The mode of the ControlNet")
control_weight: float = InputField(default=1.0, ge=-1, le=2, description="The weight given to the ControlNet")
def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
image_in = resize_img(context.images.get_pil(self.control_image.image_name))
if self.control_mode == "canny":
control_image = extract_canny(image_in)
elif self.control_mode == "depth":
control_image = extract_depth(image_in, context)
elif self.control_mode == "pose":
control_image = extract_openpose(image_in, context)
elif self.control_mode == "colorgrid":
control_image = tile(64, image_in)
elif self.control_mode == "recolor":
control_image = convert_to_grayscale(image_in)
elif self.control_mode == "tile":
control_image = tile(16, image_in)
control_image = resize_img(control_image)
image_dto = context.images.save(image=control_image)
image_output = ImageOutput.build(image_dto)
return BriaControlNetOutput(
preprocessed_images=image_output.image,
control=BriaControlNetField(
image=ImageField(image_name=image_dto.image_name),
model=self.control_model,
mode=self.control_mode,
conditioning_scale=self.control_weight,
),
)
RATIO_CONFIGS_1024 = {
0.6666666666666666: {"width": 832, "height": 1248},
0.7432432432432432: {"width": 880, "height": 1184},
0.8028169014084507: {"width": 912, "height": 1136},
1.0: {"width": 1024, "height": 1024},
1.2456140350877194: {"width": 1136, "height": 912},
1.3454545454545455: {"width": 1184, "height": 880},
1.4339622641509433: {"width": 1216, "height": 848},
1.5: {"width": 1248, "height": 832},
1.5490196078431373: {"width": 1264, "height": 816},
1.62: {"width": 1296, "height": 800},
1.7708333333333333: {"width": 1360, "height": 768},
}
def extract_depth(image: Image.Image, context: InvocationContext):
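# Generate a depth map with the remote Depth Anything V2 Small model, loaded via DepthAnythingPipeline.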
loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model)
with loaded_model as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
return depth_map
def extract_openpose(image: Image.Image, context: InvocationContext):
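# Run OpenPose detection (body, hands, face) using the lllyasviel Annotators models, then resize back to the input size.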
body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body)
hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand)
face_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}facenet.pth", Face)
with body_model as body_model, hand_model as hand_model, face_model as face_model:
open_pose_model = OpenposeDetector(body_model, hand_model, face_model)
processed_image_open_pose = open_pose_model(image, hand_and_face=True)
processed_image_open_pose = processed_image_open_pose.resize(image.size)
return processed_image_open_pose
def extract_canny(input_image):
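# Detect edges with OpenCV's Canny and stack the single channel into a 3-channel image.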
image = np.array(input_image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
return canny_image
def convert_to_grayscale(image):
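# Convert to grayscale, then back to RGB so the result keeps three channels.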
gray_image = image.convert("L").convert("RGB")
return gray_image
def tile(downscale_factor, input_image):
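# Downscale by the given factor, then upscale back with nearest-neighbor to produce a blocky tile/colorgrid control image.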
control_image = input_image.resize(
(input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
).resize(input_image.size, Image.Resampling.NEAREST)
return control_image
def resize_img(control_image):
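# Snap the image to the closest supported aspect ratio in RATIO_CONFIGS_1024 and resize with Lanczos.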
image_ratio = control_image.width / control_image.height
ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
to_height = RATIO_CONFIGS_1024[ratio]["height"]
to_width = RATIO_CONFIGS_1024[ratio]["width"]
resized_image = control_image.resize((to_width, to_height), resample=Image.Resampling.LANCZOS)
return resized_image

View File

@@ -0,0 +1,46 @@
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from PIL import Image
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.invocation_api import BaseInvocation, Classification, ImageOutput, invocation
@invocation(
"bria_decoder",
title="Decoder - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaDecoderInvocation(BaseInvocation):
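"""Decodes Bria latents into an image using the VAE."""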
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
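# Unpack the packed latent sequence (64x64 patches of 4 channels in 2x2 blocks) into a spatial (1, 4, 128, 128) latent for the VAE.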
latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
with context.models.load(self.vae.vae) as vae:
assert isinstance(vae, AutoencoderKL)
latents = latents / vae.config.scaling_factor
latents = latents.to(device=vae.device, dtype=vae.dtype)
decoded_output = vae.decode(latents)
image = decoded_output.sample
# Convert to numpy with proper gradient handling
image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0]
img = Image.fromarray(image)
image_dto = context.images.save(image=img)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,180 @@
from typing import List, Tuple
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from invokeai.app.invocations.bria_controlnet import BriaControlNetField
from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
from invokeai.backend.bria.controlnet_utils import prepare_control_images
from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
@invocation_output("bria_denoise_output")
class BriaDenoiseInvocationOutput(BaseInvocationOutput):
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
@invocation(
"bria_denoise",
title="Denoise - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaDenoiseInvocation(BaseInvocation):
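"""Denoises Bria latents with the Bria transformer, optionally guided by one or more ControlNets."""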
num_steps: int = InputField(
default=30, title="Number of Steps", description="The number of steps to use for the denoiser"
)
guidance_scale: float = InputField(
default=5.0, title="Guidance Scale", description="The guidance scale to use for the denoiser"
)
transformer: TransformerField = InputField(
description="Bria model (Transformer) to load",
input=Input.Connection,
title="Transformer",
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
title="VAE",
)
latents: LatentsField = InputField(
description="Latents to denoise",
input=Input.Connection,
title="Latents",
)
latent_image_ids: LatentsField = InputField(
description="Latent Image IDs to denoise",
input=Input.Connection,
title="Latent Image IDs",
)
pos_embeds: LatentsField = InputField(
description="Positive Prompt Embeds",
input=Input.Connection,
title="Positive Prompt Embeds",
)
neg_embeds: LatentsField = InputField(
description="Negative Prompt Embeds",
input=Input.Connection,
title="Negative Prompt Embeds",
)
text_ids: LatentsField = InputField(
description="Text IDs",
input=Input.Connection,
title="Text IDs",
)
control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
description="ControlNet",
input=Input.Connection,
title="ControlNet",
default=None,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
latents = context.tensors.load(self.latents.latents_name)
pos_embeds = context.tensors.load(self.pos_embeds.latents_name)
neg_embeds = context.tensors.load(self.neg_embeds.latents_name)
text_ids = context.tensors.load(self.text_ids.latents_name)
latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})
device = None
dtype = None
with (
context.models.load(self.transformer.transformer) as transformer,
context.models.load(scheduler_identifier) as scheduler,
context.models.load(self.vae.vae) as vae,
context.models.load(self.t5_encoder.text_encoder) as t5_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
):
assert isinstance(transformer, BriaTransformer2DModel)
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
assert isinstance(vae, AutoencoderKL)
dtype = transformer.dtype
device = transformer.device
latents, pos_embeds, neg_embeds = (x.to(device, dtype) for x in (latents, pos_embeds, neg_embeds))
control_model, control_images, control_modes, control_scales = None, None, None, None
if self.control is not None:
control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
context=context,
vae=vae,
width=1024,
height=1024,
device=vae.device,
)
pipeline = BriaControlNetPipeline(
transformer=transformer,
scheduler=scheduler,
vae=vae,
text_encoder=t5_encoder,
tokenizer=t5_tokenizer,
controlnet=control_model,
)
pipeline.to(device=transformer.device, dtype=transformer.dtype)
latents = pipeline(
control_image=control_images,
control_mode=control_modes,
width=1024,
height=1024,
controlnet_conditioning_scale=control_scales,
num_inference_steps=self.num_steps,
max_sequence_length=128,
guidance_scale=self.guidance_scale,
latents=latents,
latent_image_ids=latent_image_ids,
text_ids=text_ids,
prompt_embeds=pos_embeds,
negative_prompt_embeds=neg_embeds,
output_type="latent",
)[0]
assert isinstance(latents, torch.Tensor)
saved_input_latents_tensor = context.tensors.save(latents)
latents_output = LatentsField(latents_name=saved_input_latents_tensor)
return BriaDenoiseInvocationOutput(latents=latents_output)
def _prepare_multi_control(
self, context: InvocationContext, vae: AutoencoderKL, width: int, height: int, device: torch.device
) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
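"""Load the ControlNet models and convert the control images, modes, and scales into the tensors expected by the pipeline."""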
control = self.control if isinstance(self.control, list) else [self.control]
control_images, control_models, control_modes, control_scales = [], [], [], []
for controlnet in control:
if controlnet is not None:
control_models.append(context.models.load(controlnet.model).model)
control_modes.append(BriaControlModes[controlnet.mode].value)
control_scales.append(controlnet.conditioning_scale)
try:
control_images.append(context.images.get_pil(controlnet.image.image_name))
except Exception:
raise FileNotFoundError(
f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline."
)
control_model = BriaMultiControlNetModel(control_models).to(device)
tensored_control_images, tensored_control_modes = prepare_control_images(
vae=vae,
control_images=control_images,
control_modes=control_modes,
width=width,
height=height,
device=device,
)
return control_model, tensored_control_images, tensored_control_modes, control_scales

View File

@@ -0,0 +1,76 @@
import torch
from invokeai.app.invocations.fields import Input, InputField, OutputField
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import (
BaseInvocationOutput,
FieldDescriptions,
LatentsField,
)
from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents
from invokeai.invocation_api import (
BaseInvocation,
Classification,
InvocationContext,
invocation,
invocation_output,
)
@invocation_output("bria_latent_sampler_output")
class BriaLatentSamplerInvocationOutput(BaseInvocationOutput):
"""Base class for nodes that output a CogView text conditioning tensor."""
latents: LatentsField = OutputField(description=FieldDescriptions.cond)
latent_image_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
@invocation(
"bria_latent_sampler",
title="Latent Sampler - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaLatentSamplerInvocation(BaseInvocation):
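"""Samples initial noise latents and latent image IDs for the Bria denoiser."""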
seed: int = InputField(
default=42,
title="Seed",
description="The seed to use for the latent sampler",
)
transformer: TransformerField = InputField(
description="Bria model (Transformer) to load",
input=Input.Connection,
title="Transformer",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
with context.models.load(self.transformer.transformer) as transformer:
device = transformer.device
dtype = transformer.dtype
height, width = 1024, 1024
generator = torch.Generator(device=device).manual_seed(self.seed)
num_channels_latents = 4
latents, latent_image_ids = prepare_latents(
batch_size=1,
num_channels_latents=num_channels_latents,
height=height,
width=width,
dtype=dtype,
device=device,
generator=generator,
)
saved_latents_tensor = context.tensors.save(latents)
saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids)
latents_output = LatentsField(latents_name=saved_latents_tensor)
latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor)
return BriaLatentSamplerInvocationOutput(
latents=latents_output,
latent_image_ids=latent_image_ids_output,
)

View File

@@ -0,0 +1,58 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import (
ModelIdentifierField,
SubModelType,
T5EncoderField,
TransformerField,
VAEField,
)
from invokeai.invocation_api import (
BaseInvocation,
BaseInvocationOutput,
Classification,
InvocationContext,
invocation,
invocation_output,
)
@invocation_output("bria_model_loader_output")
class BriaModelLoaderOutput(BaseInvocationOutput):
"""Bria base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"bria_model_loader",
title="Main Model - Bria",
tags=["model", "bria"],
version="1.0.0",
classification=Classification.Prototype,
)
class BriaModelLoaderInvocation(BaseInvocation):
"""Loads a bria base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description="Bria model (Transformer) to load",
ui_type=UIType.BriaMainModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> BriaModelLoaderOutput:
for key in [self.model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return BriaModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
t5_encoder=T5EncoderField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[]),
vae=VAEField(vae=vae),
)

View File

@@ -0,0 +1,93 @@
from typing import Optional
import torch
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.app.invocations.model import T5EncoderField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions, Input, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.pipeline_bria_controlnet import encode_prompt
from invokeai.invocation_api import (
BaseInvocation,
Classification,
InputField,
LatentsField,
invocation,
invocation_output,
)
@invocation_output("bria_text_encoder_output")
class BriaTextEncoderInvocationOutput(BaseInvocationOutput):
"""Base class for nodes that output a CogView text conditioning tensor."""
pos_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
neg_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
text_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
@invocation(
"bria_text_encoder",
title="Prompt - Bria",
tags=["prompt", "conditioning", "bria"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaTextEncoderInvocation(BaseInvocation):
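"""Encodes the positive and negative prompts with the T5 text encoder for Bria."""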
prompt: str = InputField(
title="Prompt",
description="The prompt to encode",
)
negative_prompt: Optional[str] = InputField(
title="Negative Prompt",
description="The negative prompt to encode",
default="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate",
)
max_length: int = InputField(
default=128,
title="Max Length",
description="The maximum length of the prompt",
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput:
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
with (
t5_encoder_info as text_encoder,
t5_tokenizer_info as tokenizer,
):
assert isinstance(tokenizer, T5TokenizerFast)
assert isinstance(text_encoder, T5EncoderModel)
(prompt_embeds, negative_prompt_embeds, text_ids) = encode_prompt(
prompt=self.prompt,
tokenizer=tokenizer,
text_encoder=text_encoder,
negative_prompt=self.negative_prompt,
device=text_encoder.device,
num_images_per_prompt=1,
max_sequence_length=self.max_length,
lora_scale=1.0,
)
saved_pos_tensor = context.tensors.save(prompt_embeds)
saved_neg_tensor = context.tensors.save(negative_prompt_embeds)
saved_text_ids_tensor = context.tensors.save(text_ids)
pos_embeds_output = LatentsField(latents_name=saved_pos_tensor)
neg_embeds_output = LatentsField(latents_name=saved_neg_tensor)
text_ids_output = LatentsField(latents_name=saved_text_ids_tensor)
return BriaTextEncoderInvocationOutput(
pos_embeds=pos_embeds_output,
neg_embeds=neg_embeds_output,
text_ids=text_ids_output,
)

View File

@@ -1,7 +1,7 @@
from typing import Iterator, List, Optional, Tuple, Union, cast
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel import Compel, ReturnedEmbeddingsType, SplitLongTextMode
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -104,6 +104,7 @@ class CompelInvocation(BaseInvocation):
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
split_long_text_mode=SplitLongTextMode.SENTENCES,
)
conjunction = Compel.parse_prompt_string(self.prompt)
@@ -113,6 +114,13 @@ class CompelInvocation(BaseInvocation):
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info
c = c.detach().to("cpu")
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
@@ -205,6 +213,7 @@ class SDXLPromptInvocationBase:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
split_long_text_mode=SplitLongTextMode.SENTENCES,
)
conjunction = Compel.parse_prompt_string(prompt)
@@ -220,7 +229,10 @@ class SDXLPromptInvocationBase:
else:
c_pooled = None
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info

View File

@@ -274,12 +274,12 @@ class InvokeAdjustImageHuePlusInvocation(BaseInvocation, WithMetadata, WithBoard
title="Enhance Image",
tags=["enhance", "image"],
category="image",
version="1.2.0",
version="1.2.1",
)
class InvokeImageEnhanceInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Applies processing from PIL's ImageEnhance module. Originally created by @dwringer"""
image: ImageField = InputField(default=None, description="The image for which to apply processing")
image: ImageField = InputField(description="The image for which to apply processing")
invert: bool = InputField(default=False, description="Whether to invert the image colors")
color: float = InputField(ge=0, default=1.0, description="Color enhancement factor")
contrast: float = InputField(ge=0, default=1.0, description="Contrast enhancement factor")

View File

@@ -22,7 +22,11 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.app.util.controlnet_utils import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
heuristic_resize_fast,
)
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
@@ -109,7 +113,7 @@ class ControlNetInvocation(BaseInvocation):
title="Heuristic Resize",
tags=["image, controlnet"],
category="image",
version="1.0.1",
version="1.1.1",
classification=Classification.Prototype,
)
class HeuristicResizeInvocation(BaseInvocation):
@@ -122,7 +126,7 @@ class HeuristicResizeInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
np_img = pil_to_np(image)
np_resized = heuristic_resize(np_img, (self.width, self.height))
np_resized = heuristic_resize_fast(np_img, (self.width, self.height))
resized = np_to_pil(np_resized)
image_dto = context.images.save(image=resized)
return ImageOutput.build(image_dto)

View File

@@ -1,12 +1,14 @@
from typing import Literal, Optional
import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
@@ -42,15 +44,13 @@ class GradientMaskOutput(BaseInvocationOutput):
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.2.0",
version="1.3.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
"""Creates mask for denoising."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
mask: ImageField = InputField(description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(default=16, ge=0, description="How far to expand the edges of the mask", ui_order=2)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
@@ -81,45 +81,110 @@ class CreateGradientMaskInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
# Resize the mask_image. Makes the filter 64x faster and doesn't hurt quality since the mask is used at latent scale anyway
mask_image = mask_image.resize(
(
mask_image.width // LATENT_SCALE_FACTOR,
mask_image.height // LATENT_SCALE_FACTOR,
),
resample=Image.Resampling.BILINEAR,
)
mask_np_orig = np.array(mask_image, dtype=np.float32)
self.edge_radius = self.edge_radius // LATENT_SCALE_FACTOR # scale the edge radius to match the mask size
if self.edge_radius > 0:
mask_np = 255 - mask_np_orig # invert so 0 is unmasked (higher values = higher denoise strength)
dilated_mask = mask_np.copy()
# Create kernel based on coherence mode
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
# Create a circular distance kernel that fades from center outward
kernel_size = self.edge_radius * 2 + 1
center = self.edge_radius
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
for i in range(kernel_size):
for j in range(kernel_size):
dist = np.sqrt((i - center) ** 2 + (j - center) ** 2)
if dist <= self.edge_radius:
kernel[i, j] = 1.0 - (dist / self.edge_radius)
else: # Gaussian Blur or Staged
# Create a Gaussian kernel
kernel_size = self.edge_radius * 2 + 1
kernel = cv2.getGaussianKernel(
kernel_size, self.edge_radius / 2.5
) # 2.5 is a magic number relating the edge radius to the Gaussian's standard deviation
kernel = kernel * kernel.T # Make 2D gaussian kernel
kernel = kernel / np.max(kernel) # Normalize center to 1.0
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# Ensure values outside radius are 0
center = self.edge_radius
for i in range(kernel_size):
for j in range(kernel_size):
dist = np.sqrt((i - center) ** 2 + (j - center) ** 2)
if dist > self.edge_radius:
kernel[i, j] = 0
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
blur_tensor[blur_tensor < 0] = 0.0
# 2D max filter
mask_tensor = torch.tensor(mask_np)
kernel_tensor = torch.tensor(kernel)
dilated_mask = 255 - self.max_filter2D_torch(mask_tensor, kernel_tensor).cpu()
dilated_mask = dilated_mask.numpy()
threshold = 1 - self.minimum_denoise
threshold = (1 - self.minimum_denoise) * 255
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
# wherever expanded mask is darker than the original mask but original was above threshold, set it to the threshold
# makes any expansion areas drop to threshold. Raising the minimum across the image happens outside of this if block
threshold_mask = (dilated_mask < mask_np_orig) & (mask_np_orig > threshold)
dilated_mask = np.where(threshold_mask, threshold, mask_np_orig)
# wherever expanded mask is less than 255 but greater than threshold, drop it to threshold (minimum denoise)
threshold_mask = (dilated_mask > threshold) & (dilated_mask < 255)
dilated_mask = np.where(threshold_mask, threshold, dilated_mask)
else:
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
dilated_mask = mask_np_orig.copy()
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# convert to tensor
dilated_mask = np.clip(dilated_mask, 0, 255).astype(np.uint8)
mask_tensor = torch.tensor(dilated_mask, device=torch.device("cpu"))
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
# binary mask for compositing
expanded_mask = np.where((dilated_mask < 255), 0, 255)
expanded_mask_image = Image.fromarray(expanded_mask.astype(np.uint8), mode="L")
expanded_mask_image = expanded_mask_image.resize(
(
mask_image.width * LATENT_SCALE_FACTOR,
mask_image.height * LATENT_SCALE_FACTOR,
),
resample=Image.Resampling.NEAREST,
)
expanded_image_dto = context.images.save(expanded_mask_image)
# restore the original mask size
dilated_mask = Image.fromarray(dilated_mask.astype(np.uint8))
dilated_mask = dilated_mask.resize(
(
mask_image.width * LATENT_SCALE_FACTOR,
mask_image.height * LATENT_SCALE_FACTOR,
),
resample=Image.Resampling.NEAREST,
)
# stack the mask as a tensor, repeating 4 times on dimension 1
dilated_mask_tensor = image_resized_to_grid_as_tensor(dilated_mask, normalize=False)
mask_name = context.tensors.save(tensor=dilated_mask_tensor.unsqueeze(0))
masked_latents_name = None
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
mask = dilated_mask_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
@@ -137,3 +202,29 @@ class CreateGradientMaskInvocation(BaseInvocation):
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
def max_filter2D_torch(self, image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
"""
This morphological operation is much faster in torch than numpy or opencv
For reasonable kernel sizes, the overhead of copying the data to the GPU is not worth it.
"""
h, w = kernel.shape
pad_h, pad_w = h // 2, w // 2
padded = torch.nn.functional.pad(image, (pad_w, pad_w, pad_h, pad_h), mode="constant", value=0)
result = torch.zeros_like(image)
# This loop iterates over kernel offsets rather than image pixels; it computes the same max filter but more efficiently, since each shift is vectorized over the whole image
for i in range(h):
for j in range(w):
weight = kernel[i, j]
if weight <= 0:
continue
# Extract the region from padded tensor
region = padded[i : i + image.shape[0], j : j + image.shape[1]]
# Apply weight and update max
result = torch.maximum(result, region * weight)
return result

View File

@@ -608,6 +608,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
mask=mask,
method=single_ip_adapter.method,
)
)

View File

@@ -42,6 +42,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
BriaMainModel = "BriaMainModelField"
BriaControlNetModel = "BriaControlNetModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
@@ -61,6 +63,10 @@ class UIType(str, Enum, metaclass=MetaEnum):
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
FluxKontextModel = "FluxKontextModelField"
# endregion
# region Misc Field Types
@@ -211,6 +217,7 @@ class FieldDescriptions:
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"
flux_fill_conditioning = "FLUX Fill conditioning tensor"
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"
class ImageField(BaseModel):
@@ -287,6 +294,12 @@ class FluxFillConditioningField(BaseModel):
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
class FluxKontextConditioningField(BaseModel):
"""A conditioning field for FLUX Kontext (reference image)."""
image: ImageField = Field(description="The Kontext reference image.")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
@@ -398,8 +411,8 @@ class InputFieldJSONSchemaExtra(BaseModel):
"""
input: Input
orig_required: bool
field_kind: FieldKind
orig_required: bool = True
default: Optional[Any] = None
orig_default: Optional[Any] = None
ui_hidden: bool = False
@@ -434,7 +447,7 @@ class WithWorkflow:
workflow = None
def __init_subclass__(cls) -> None:
logger.warn(
logger.warning(
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
)
super().__init_subclass__()
@@ -496,7 +509,7 @@ def InputField(
input: Input = Input.Any,
ui_type: Optional[UIType] = None,
ui_component: Optional[UIComponent] = None,
ui_hidden: bool = False,
ui_hidden: Optional[bool] = None,
ui_order: Optional[int] = None,
ui_choice_labels: Optional[dict[str, str]] = None,
) -> Any:
@@ -532,15 +545,20 @@ def InputField(
json_schema_extra_ = InputFieldJSONSchemaExtra(
input=input,
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
ui_order=ui_order,
ui_choice_labels=ui_choice_labels,
field_kind=FieldKind.Input,
orig_required=True,
)
if ui_type is not None:
json_schema_extra_.ui_type = ui_type
if ui_component is not None:
json_schema_extra_.ui_component = ui_component
if ui_hidden is not None:
json_schema_extra_.ui_hidden = ui_hidden
if ui_order is not None:
json_schema_extra_.ui_order = ui_order
if ui_choice_labels is not None:
json_schema_extra_.ui_choice_labels = ui_choice_labels
"""
There is a conflict between the typing of invocation definitions and the typing of an invocation's
`invoke()` function.
@@ -570,7 +588,7 @@ def InputField(
if default_factory is not _Unset and default_factory is not None:
default = default_factory()
logger.warn('"default_factory" is not supported, calling it now to set "default"')
logger.warning('"default_factory" is not supported, calling it now to set "default"')
# These are the args we may wish pass to the pydantic `Field()` function
field_args = {
@@ -612,7 +630,7 @@ def InputField(
return Field(
**provided_args,
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
json_schema_extra=json_schema_extra_.model_dump(exclude_unset=True),
)

View File

@@ -16,13 +16,12 @@ from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
FluxFillConditioningField,
FluxKontextConditioningField,
FluxReduxConditioningField,
ImageField,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
@@ -34,6 +33,7 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXCo
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
@@ -63,9 +63,9 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.3.0",
version="4.0.0",
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
class FluxDenoiseInvocation(BaseInvocation):
"""Run denoising process with a FLUX transformer model."""
# If latents is provided, this means we are doing image-to-image.
@@ -145,11 +145,20 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.vae,
input=Input.Connection,
)
# This node accepts images for features like FLUX Fill, ControlNet, and Kontext, but needs to operate on them in
# latent space. We'll run the VAE to encode them in this node instead of requiring the user to run the VAE in
# upstream nodes.
ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
)
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
default=None,
description="FLUX Kontext conditioning (reference image).",
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
@@ -376,6 +385,27 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
)
kontext_extension = None
if self.kontext_conditioning is not None:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning,
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
# Prepare Kontext conditioning if provided
img_cond_seq = None
img_cond_seq_ids = None
if kontext_extension is not None:
# Ensure batch sizes match
kontext_extension.ensure_batch_size(x.shape[0])
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
x = denoise(
model=transformer,
img=x,
@@ -391,6 +421,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
img_cond_seq=img_cond_seq,
img_cond_seq_ids=img_cond_seq_ids,
)
x = unpack(x.float(), self.height, self.width)
@@ -865,7 +897,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
# The denoise function now handles Kontext conditioning correctly,
# so we don't need to slice the latents here
latents = state.latents.float()
state.latents = unpack(latents, self.height, self.width).squeeze()
context.util.flux_step_callback(state)
return step_callback
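
The new Kontext path above VAE-encodes a reference image and hands the resulting token sequence (plus its position ids) to denoise as img_cond_seq / img_cond_seq_ids. As a rough illustration of the idea only (not the repository's KontextExtension), the reference latents can be packed into FLUX-style 2x2 patch tokens and appended to the main image token sequence so the transformer attends to them during denoising; helper names below are hypothetical.

import torch

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # Rearrange (B, C, H, W) VAE latents into (B, H/2 * W/2, C * 4) patch tokens,
    # matching the 2x2 packing FLUX uses for its image token sequence.
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

def append_reference_tokens(img_seq, img_ids, ref_latents, ref_ids):
    # Hypothetical helper: concatenate the packed reference tokens (and their
    # position ids) after the main image tokens along the sequence dimension.
    ref_seq = pack_latents(ref_latents)
    return torch.cat([img_seq, ref_seq], dim=1), torch.cat([img_ids, ref_ids], dim=1)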

View File

@@ -0,0 +1,40 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxKontextConditioningField,
InputField,
OutputField,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_kontext_output")
class FluxKontextOutput(BaseInvocationOutput):
"""The conditioning output of a FLUX Kontext invocation."""
kontext_cond: FluxKontextConditioningField = OutputField(
description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
)
@invocation(
"flux_kontext",
title="Kontext Conditioning - FLUX",
tags=["conditioning", "kontext", "flux"],
category="conditioning",
version="1.0.0",
)
class FluxKontextInvocation(BaseInvocation):
"""Prepares a reference image for FLUX Kontext conditioning."""
image: ImageField = InputField(description="The Kontext reference image.")
def invoke(self, context: InvocationContext) -> FluxKontextOutput:
"""Packages the provided image into a Kontext conditioning field."""
return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))

View File

@@ -3,6 +3,7 @@ from typing import Literal, Optional
import torch
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -115,8 +116,14 @@ class FluxReduxInvocation(BaseInvocation):
@torch.no_grad()
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
siglip_model_config = self._get_siglip_model(context)
with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
assert isinstance(siglip_pipeline, SigLipPipeline)
with context.models.load(siglip_model_config.key).model_on_device() as (_, model):
assert isinstance(model, SiglipVisionModel)
model_abs_path = context.models.get_absolute_path(siglip_model_config)
processor = SiglipImageProcessor.from_pretrained(model_abs_path, local_files_only=True)
assert isinstance(processor, SiglipImageProcessor)
siglip_pipeline = SigLipPipeline(processor, model)
return siglip_pipeline.encode_image(
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
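
The change above loads the raw SiglipVisionModel through the model manager and builds the SiglipImageProcessor from the model's local path, then wraps both in SigLipPipeline. A minimal standalone sketch of the same transformers pattern (the directory and file names are hypothetical):

import torch
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel

model_dir = "/models/siglip"  # hypothetical local model directory
model = SiglipVisionModel.from_pretrained(model_dir, local_files_only=True)
processor = SiglipImageProcessor.from_pretrained(model_dir, local_files_only=True)

image = Image.open("reference.png").convert("RGB")  # hypothetical input image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    # pooler_output is the pooled image embedding produced by the vision tower.
    embedding = model(**inputs).pooler_output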

View File

@@ -1,5 +1,5 @@
from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple
from typing import Iterator, Literal, Optional, Tuple, Union
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
@@ -111,6 +111,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
if context.config.get().log_tokenization:
self._log_t5_tokenization(context, t5_tokenizer)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
@@ -151,6 +154,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
if context.config.get().log_tokenization:
self._log_clip_tokenization(context, clip_tokenizer)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
@@ -170,3 +176,88 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _log_t5_tokenization(
self,
context: InvocationContext,
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
) -> None:
"""Logs the tokenization of a prompt for a T5-based model like FLUX."""
# Tokenize the prompt using the same parameters as the model's text encoder.
# T5 tokenizers add an EOS token (</s>) and then pad to max_length.
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=self.t5_max_seq_len,
truncation=True,
add_special_tokens=True, # This is important for T5 to add the EOS token.
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The T5 tokenizer uses the space-like character '▁' (U+2581) to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("\u2581", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for token in tokens:
if token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
used_tokens += 1
elif token == tokenizer.pad_token:
# tokenized_str += f"\x1b[0;34m{token}\x1b[0m" # Blue for PAD
continue
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
context.logger.info(f"{tokenized_str}\x1b[0m")
def _log_clip_tokenization(
self,
context: InvocationContext,
tokenizer: CLIPTokenizer,
) -> None:
"""Logs the tokenization of a prompt for a CLIP-based model."""
max_length = tokenizer.model_max_length
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
attention_mask = tokenized_output.attention_mask[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The CLIP tokenizer uses '</w>' to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("</w>", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for i, token in enumerate(tokens):
if attention_mask[i] == 0:
# Do not log padding tokens.
continue
if token == tokenizer.bos_token:
tokenized_str += f"\x1b[0;32m{token}\x1b[0m" # Green for BOS
elif token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
context.logger.info(f"{tokenized_str}\x1b[0m")

View File

@@ -21,14 +21,14 @@ class IdealSizeOutput(BaseInvocationOutput):
"ideal_size",
title="Ideal Size - SD1.5, SDXL",
tags=["latents", "math", "ideal_size"],
version="1.0.5",
version="1.0.6",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
width: int = InputField(default=1024, description="Final image width")
height: int = InputField(default=576, description="Final image height")
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
unet: UNetField = InputField(description=FieldDescriptions.unet)
multiplier: float = InputField(
default=1.0,
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "

View File

@@ -975,13 +975,13 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Canvas Paste Back",
tags=["image", "combine"],
category="image",
version="1.0.0",
version="1.0.1",
)
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
source_image: ImageField = InputField(description="The source image")
target_image: ImageField = InputField(default=None, description="The target image")
target_image: ImageField = InputField(description="The target image")
mask: ImageField = InputField(
description="The mask to use when pasting",
)
@@ -1218,12 +1218,15 @@ class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Add Image Noise",
tags=["image", "noise"],
category="image",
version="1.0.1",
version="1.1.0",
)
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add noise to an image"""
image: ImageField = InputField(description="The image to add noise to")
mask: Optional[ImageField] = InputField(
default=None, description="Optional mask determining where to apply noise (black=noise, white=no noise)"
)
seed: int = InputField(
default=0,
ge=0,
@@ -1267,12 +1270,27 @@ class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
(image.width, image.height), Image.Resampling.NEAREST
)
# Create a noisy version of the input image
noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")
# Paste back the alpha channel
noisy_image.putalpha(alpha)
# Apply mask if provided
if self.mask is not None:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
image_dto = context.images.save(image=noisy_image)
if mask_image.size != image.size:
mask_image = mask_image.resize(image.size, Image.Resampling.LANCZOS)
result_image = image.copy()
mask_image = ImageOps.invert(mask_image)
result_image.paste(noisy_image, (0, 0), mask=mask_image)
else:
result_image = noisy_image
# Paste back the alpha channel from the original image
result_image.putalpha(alpha)
image_dto = context.images.save(image=result_image)
return ImageOutput.build(image_dto)
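
Stripped of the invocation machinery, the new mask handling is a standard PIL masked paste: the mask is resized to the image, inverted so black (noise) areas become opaque in the paste mask, and the noisy copy is pasted over the original. A minimal sketch with hypothetical file names:

from PIL import Image, ImageOps

base = Image.open("input.png").convert("RGBA")   # original image
noisy = Image.open("noisy.png").convert("RGBA")  # noised copy of the same image
mask = Image.open("mask.png").convert("L")       # black = apply noise, white = keep original

if mask.size != base.size:
    mask = mask.resize(base.size, Image.Resampling.LANCZOS)

result = base.copy()
# Invert so regions that were black in the input mask become 255 (fully pasted).
result.paste(noisy, (0, 0), mask=ImageOps.invert(mask))
result.save("result.png")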

View File

@@ -31,6 +31,7 @@ class IPAdapterField(BaseModel):
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
method: str = Field(default="full", description="Weight apply method")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
@@ -94,7 +95,7 @@ class IPAdapterInvocation(BaseInvocation):
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
method: Literal["full", "style", "composition"] = InputField(
method: Literal["full", "style", "composition", "style_strong", "style_precise"] = InputField(
default="full", description="The method to apply the IP-Adapter"
)
begin_step_percent: float = InputField(
@@ -147,6 +148,38 @@ class IPAdapterInvocation(BaseInvocation):
target_blocks = ["down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "style_precise":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.1", "down_blocks.2", "mid_block"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "style_strong":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.0", "up_blocks.1", "up_blocks.2", "down_blocks.0", "down_blocks.1"]
elif ip_adapter_info.base == "sdxl":
target_blocks = [
"up_blocks.0.attentions.1",
"up_blocks.1.attentions.1",
"up_blocks.2.attentions.1",
"up_blocks.0.attentions.2",
"up_blocks.1.attentions.2",
"up_blocks.2.attentions.2",
"up_blocks.0.attentions.0",
"up_blocks.1.attentions.0",
"up_blocks.2.attentions.0",
"down_blocks.0.attentions.0",
"down_blocks.0.attentions.1",
"down_blocks.0.attentions.2",
"down_blocks.1.attentions.0",
"down_blocks.1.attentions.1",
"down_blocks.1.attentions.2",
"down_blocks.2.attentions.0",
"down_blocks.2.attentions.2",
]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "full":
target_blocks = ["block"]
else:
@@ -162,6 +195,7 @@ class IPAdapterInvocation(BaseInvocation):
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
mask=self.mask,
method=self.method,
),
)

View File

@@ -3,13 +3,14 @@ from typing import Any
import torch
from PIL.Image import Image
from pydantic import field_validator
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
from invokeai.backend.util.devices import TorchDevice
@@ -54,10 +55,17 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> StringOutput:
images = self._get_images(context)
model_config = context.models.get_config(self.vllm_model)
with context.models.load(self.vllm_model) as vllm_model:
assert isinstance(vllm_model, LlavaOnevisionModel)
output = vllm_model.run(
with context.models.load(self.vllm_model).model_on_device() as (_, model):
assert isinstance(model, LlavaOnevisionForConditionalGeneration)
model_abs_path = context.models.get_absolute_path(model_config)
processor = AutoProcessor.from_pretrained(model_abs_path, local_files_only=True)
assert isinstance(processor, LlavaOnevisionProcessor)
model = LlavaOnevisionPipeline(model, processor)
output = model.run(
prompt=self.prompt,
images=images,
device=TorchDevice.choose_torch_device(),

View File

@@ -42,7 +42,9 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-L", "ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
method: Literal["full", "style", "composition", "style_strong", "style_precise"] = Field(
description="Method to apply IP Weights with"
)
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")

View File

@@ -430,6 +430,15 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("flux_conditioning_collection_output")
class FluxConditioningCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of conditioning tensors"""
collection: list[FluxConditioningField] = OutputField(
description="The output conditioning tensors",
)
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers import AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
@@ -104,14 +104,13 @@ class SegmentAnythingInvocation(BaseInvocation):
@staticmethod
def _load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained(
sam_model = SamModel.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)

View File

@@ -1,12 +1,3 @@
import uvicorn
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
def get_app():
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
@@ -18,9 +9,18 @@ def get_app():
def run_app() -> None:
"""The main entrypoint for the app."""
# Parse the CLI arguments.
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
# Parse the CLI arguments before doing anything else, which ensures CLI args correctly override settings from other
# sources like `invokeai.yaml` or env vars.
InvokeAIArgs.parse_args()
import uvicorn
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
from invokeai.backend.util.logging import InvokeAILogger
# Load config.
app_config = get_config()
@@ -32,6 +32,8 @@ def run_app() -> None:
configure_torch_cuda_allocator(app_config.pytorch_cuda_alloc_conf, logger)
# This import must happen after configure_torch_cuda_allocator() is called, because the module imports torch.
from invokeai.app.invocations.baseinvocation import InvocationRegistry
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.backend.util.devices import TorchDevice
torch_device_name = TorchDevice.get_torch_device_name()
@@ -66,6 +68,15 @@ def run_app() -> None:
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
# Check all invocations and ensure their outputs are registered.
for invocation in InvocationRegistry.get_invocation_classes():
invocation_type = invocation.get_type()
output_annotation = invocation.get_output_annotation()
if output_annotation not in InvocationRegistry.get_output_classes():
logger.warning(
f'Invocation "{invocation_type}" has unregistered output class "{output_annotation.__name__}"'
)
if app_config.dev_reload:
# load_custom_nodes seems to bypass jurrigged's import sniffer, so be sure to call it *after* they're already
# imported.

View File

@@ -14,15 +14,14 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
@@ -31,17 +30,12 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_image_from_board(
self,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM board_images
@@ -49,10 +43,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_images_for_board(
self,
@@ -60,27 +50,26 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
@@ -90,47 +79,55 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
params: list[str | bool] = []
with self._db.transaction() as cursor:
params: list[str | bool] = []
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
AND board_images.board_id = ?
"""
params.append(board_id)
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Put a ring on it
stmt += ";"
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Execute the query
cursor = self._conn.cursor()
cursor.execute(stmt, params)
# Put a ring on it
stmt += ";"
result = cast(list[sqlite3.Row], cursor.fetchall())
cursor.execute(stmt, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
@@ -138,31 +135,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> Optional[str]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
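
This file and the record-storage classes below all follow the same refactor: instead of holding db.conn and manually calling commit()/rollback(), they hold the SqliteDatabase and run each operation inside with self._db.transaction() as cursor. A minimal sketch of what such a context manager typically looks like (not the repository's exact implementation; class name is hypothetical):

import sqlite3
from contextlib import contextmanager
from typing import Iterator

class SqliteDatabaseSketch:
    """Hypothetical stand-in for the SqliteDatabase wrapper."""

    def __init__(self, db_path: str) -> None:
        self.conn = sqlite3.connect(db_path)
        self.conn.row_factory = sqlite3.Row

    @contextmanager
    def transaction(self) -> Iterator[sqlite3.Cursor]:
        # Yield a cursor; commit when the block exits cleanly, roll back on error.
        cursor = self.conn.cursor()
        try:
            yield cursor
            self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise
        finally:
            cursor.close()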

View File

@@ -20,61 +20,57 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def delete(self, board_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
except Exception as e:
raise BoardRecordDeleteException from e
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = uuid_string()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
with self._db.transaction() as cursor:
try:
board_id = uuid_string()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -84,45 +80,43 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
with self._db.transaction() as cursor:
try:
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
def get_many(
@@ -133,78 +127,77 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
cursor = self._conn.cursor()
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
with self._db.transaction() as cursor:
# Build base query
base_query = """
SELECT *
FROM boards
WHERE archived = 0;
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Execute count query
cursor.execute(count_query)
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
count = cast(int, cursor.fetchone()[0])
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
cursor.execute(count_query)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query)
cursor.execute(final_query)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards

View File

@@ -24,7 +24,6 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
@@ -93,7 +92,7 @@ class InvokeAIAppConfig(BaseSettings):
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -176,7 +175,7 @@ class InvokeAIAppConfig(BaseSettings):
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION
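
The device setting switches from a fixed Literal to a pattern-validated string so any cuda:N index is accepted. A quick pydantic illustration of the same pattern constraint (the model name here is hypothetical):

from pydantic import BaseModel, Field, ValidationError

class DeviceSettingSketch(BaseModel):
    device: str = Field(default="auto", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$")

DeviceSettingSketch(device="cuda:3")  # accepted now that any device index matches
try:
    DeviceSettingSketch(device="rocm")
except ValidationError:
    pass  # rejected by the pattern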

View File

@@ -8,6 +8,7 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from shutil import disk_usage
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
import requests
@@ -335,6 +336,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
assert job.download_path
free_space = disk_usage(job.download_path.parent).free
GB = 2**30
self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
if free_space < job.total_bytes:
raise RuntimeError(
f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
)
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():
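
For context, shutil.disk_usage reports total/used/free bytes for the filesystem containing the given path; the pre-download check added above boils down to something like this (the path and size are hypothetical):

from shutil import disk_usage

GB = 2**30
download_dir = "/data/models"  # hypothetical destination directory
total_bytes = 12 * GB          # hypothetical size of the pending download

free = disk_usage(download_dir).free
if free < total_bytes:
    raise RuntimeError(
        f"Free disk space {free / GB:.2f} GB is not enough for download of {total_bytes / GB:.2f} GB."
    )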

View File

@@ -241,6 +241,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
@classmethod
def build(
@@ -263,6 +264,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
credits=queue_item.credits,
)

View File

@@ -5,6 +5,7 @@ from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -97,3 +98,17 @@ class ImageRecordStorageBase(ABC):
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the most recent image for a board."""
pass
@abstractmethod
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass

View File

@@ -3,7 +3,7 @@ import datetime
from enum import Enum
from typing import Optional, Union
from pydantic import Field, StrictBool, StrictStr
from pydantic import BaseModel, Field, StrictBool, StrictStr
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import get_iso_timestamp
@@ -207,3 +207,16 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
starred=starred,
has_workflow=has_workflow,
)
class ImageCollectionCounts(BaseModel):
starred_count: int = Field(description="The number of starred images in the collection.")
unstarred_count: int = Field(description="The number of unstarred images in the collection.")
class ImageNamesResult(BaseModel):
"""Response containing ordered image names with metadata for optimistic updates."""
image_names: list[str] = Field(description="Ordered list of image names")
starred_count: int = Field(description="Number of starred images (when starred_first=True)")
total_count: int = Field(description="Total number of images matching the query")

View File

@@ -7,6 +7,7 @@ from invokeai.app.services.image_records.image_records_base import ImageRecordSt
from invokeai.app.services.image_records.image_records_common import (
IMAGE_DTO_COLS,
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -23,22 +24,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def get(self, image_name: str) -> ImageRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -46,17 +47,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -64,64 +68,60 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
as_dict = dict(result)
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> None:
try:
cursor = self._conn.cursor()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
with self._db.transaction() as cursor:
try:
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `is_intermediate` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `starred` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
# Change the image's `starred` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
except sqlite3.Error as e:
raise ImageRecordSaveException from e
def get_many(
self,
@@ -135,166 +135,162 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
cursor = self._conn.cursor()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
with self._db.transaction() as cursor:
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_params.append(is_intermediate)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
query_params.append(is_intermediate)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
try:
placeholders = ",".join("?" for _ in image_names)
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
return count
def delete_intermediates(self) -> list[str]:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
self._conn.commit()
return image_names
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
return image_names
def save(
self,
@@ -310,75 +306,165 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str] = None,
metadata: Optional[str] = None,
) -> datetime:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
self._conn.commit()
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(cursor.fetchone()[0])
created_at = datetime.fromisoformat(cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
except sqlite3.Error as e:
raise ImageRecordSaveException from e
return created_at
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if result is None:
return None
return deserialize_image_record(dict(result))
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
with self._db.transaction() as cursor:
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [row[0] for row in result]
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))

View File

@@ -6,6 +6,7 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -125,7 +126,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
"""Gets a paginated list of image DTOs with starred images first when starred_first=True."""
pass
@abstractmethod
@@ -147,3 +148,17 @@ class ImageServiceABC(ABC):
def delete_images_on_board(self, board_id: str):
"""Deletes all images on a board."""
pass
@abstractmethod
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field
from pydantic import BaseModel, Field
from invokeai.app.services.image_records.image_records_common import ImageRecord
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
@@ -39,3 +39,27 @@ def image_record_to_dto(
thumbnail_url=thumbnail_url,
board_id=board_id,
)
class ResultWithAffectedBoards(BaseModel):
affected_boards: list[str] = Field(description="The ids of boards affected by the delete operation")
class DeleteImagesResult(ResultWithAffectedBoards):
deleted_images: list[str] = Field(description="The names of the images that were deleted")
class StarredImagesResult(ResultWithAffectedBoards):
starred_images: list[str] = Field(description="The names of the images that were starred")
class UnstarredImagesResult(ResultWithAffectedBoards):
unstarred_images: list[str] = Field(description="The names of the images that were unstarred")
class AddImagesToBoardResult(ResultWithAffectedBoards):
added_images: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(ResultWithAffectedBoards):
removed_images: list[str] = Field(description="The image names that were removed from their board")
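As a rough illustration (values are hypothetical), the new result models are plain pydantic containers that pair the affected board ids with the image names touched by an operation:
# Hypothetical construction of one of the new result models.
result = DeleteImagesResult(
    affected_boards=["board-123", "none"],
    deleted_images=["img-aaa.png", "img-bbb.png"],
)
assert result.affected_boards == ["board-123", "none"]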

View File

@@ -10,6 +10,7 @@ from invokeai.app.services.image_files.image_files_common import (
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -78,7 +79,7 @@ class ImageService(ImageServiceABC):
board_id=board_id, image_name=image_name
)
except Exception as e:
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.logger.warning(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
@@ -309,3 +310,27 @@ class ImageService(ImageServiceABC):
except Exception as e:
self.__invoker.services.logger.error("Problem getting intermediates count")
raise e
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
try:
return self.__invoker.services.image_records.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting image names")
raise e

View File

@@ -27,6 +27,10 @@ if TYPE_CHECKING:
from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
from invokeai.app.services.names.names_base import NameServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
@@ -54,6 +58,8 @@ class InvocationServices:
logger: "Logger",
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
model_relationships: "ModelRelationshipsServiceABC",
model_relationship_records: "ModelRelationshipRecordStorageBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
@@ -81,6 +87,8 @@ class InvocationServices:
self.logger = logger
self.model_images = model_images
self.model_manager = model_manager
self.model_relationships = model_relationships
self.model_relationship_records = model_relationship_records
self.download_queue = download_queue
self.performance_statistics = performance_statistics
self.session_queue = session_queue

View File

@@ -60,7 +60,7 @@ class InvocationStatsServiceBase(ABC):
pass
@abstractmethod
def reset_stats(self):
def reset_stats(self, graph_execution_state_id: str) -> None:
"""Reset all stored statistics."""
pass

View File

@@ -73,9 +73,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def reset_stats(self, graph_execution_state_id: str) -> None:
self._stats.pop(graph_execution_state_id, None)
self._cache_stats.pop(graph_execution_state_id, None)
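A small illustration of the behavioural change in reset_stats (dictionary contents are hypothetical): popping a single graph id leaves stats for other in-flight graphs untouched, where the old implementation cleared everything:
# Hypothetical stats store with two graphs in flight.
stats = {"graph-a": {"nodes": 3}, "graph-b": {"nodes": 7}}
stats.pop("graph-a", None)      # new behaviour: drop only the finished graph
assert "graph-b" in stats       # other graphs keep their stats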
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)

View File

@@ -51,6 +51,7 @@ from invokeai.backend.model_manager.metadata import (
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
from invokeai.backend.model_manager.util.lora_metadata_extractor import apply_lora_metadata
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
@@ -148,7 +149,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _clear_pending_jobs(self) -> None:
for job in self.list_jobs():
if not job.in_terminal_state:
self._logger.warning("Cancelling job {job.id}")
self._logger.warning(f"Cancelling job {job.id}")
self.cancel_job(job)
while True:
try:
@@ -647,10 +648,18 @@ class ModelInstallService(ModelInstallServiceBase):
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()
# WARNING!
# The legacy probe relies on the implicit order of tests to determine model classification.
# This can lead to regressions between the legacy and new probes.
# Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
# Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
# Long-term fix: Improve `matches` to be more specific so that only one config matches
# any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
try:
return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
except InvalidModelConfigException:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigBase.classify(model_path, hash_algo, **fields)
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
@@ -659,6 +668,10 @@ class ModelInstallService(ModelInstallServiceBase):
info = info or self._probe(model_path, config)
# Apply LoRA metadata if applicable
model_images_path = self.app_config.models_path / "model_images"
apply_lora_metadata(info, model_path.resolve(), model_images_path)
model_path = model_path.resolve()
# Models in the Invoke-managed models dir should use relative paths.

View File

@@ -78,11 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._logger = logger
@property
def db(self) -> SqliteDatabase:
"""Return the underlying database."""
return self._db
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@@ -93,38 +88,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
except sqlite3.IntegrityError as e:
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
raise e
return self.get_model(config.key)
@@ -136,8 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
try:
cursor = self._db.conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM models
@@ -147,22 +136,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
with self._db.transaction() as cursor:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
json_serialized = record.model_dump_json()
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
@@ -174,10 +158,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@@ -189,30 +169,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -224,15 +204,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -255,43 +235,42 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -313,69 +292,68 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
cursor = self._db.conn.cursor()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)

View File

@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
class ModelRelationshipRecordStorageBase(ABC):
"""Abstract base class for model-to-model relationship record storage."""
@abstractmethod
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Creates a relationship between two models by keys."""
pass
@abstractmethod
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Removes a relationship between two models by keys."""
pass
@abstractmethod
def get_related_model_keys(self, model_key: str) -> list[str]:
"""Gets all models keys related to a given model key."""
pass
@abstractmethod
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
"""Get related model keys for multiple models given a list of keys."""
pass

View File

@@ -0,0 +1,55 @@
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.transaction() as cursor:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.transaction() as cursor:
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
def get_related_model_keys(self, model_key: str) -> list[str]:
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
result = [row[0] for row in cursor.fetchall()]
return result
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
with self._db.transaction() as cursor:
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
result = [row[0] for row in cursor.fetchall()]
return result
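A minimal usage sketch of the relationship store (assumes the model_relationships table created by the accompanying migration; the db handle and keys are hypothetical):
store = SqliteModelRelationshipRecordStorage(db)
store.add_model_relationship("lora-abc", "main-xyz")    # keys are stored sorted, so argument order doesn't matter
store.get_related_model_keys("main-xyz")                # -> ["lora-abc"] under this setup
store.remove_model_relationship("main-xyz", "lora-abc")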

View File

@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
class ModelRelationshipsServiceABC(ABC):
"""High-level service for managing model-to-model relationships."""
@abstractmethod
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Creates a relationship between two models keys."""
pass
@abstractmethod
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Removes a relationship between two models keys."""
pass
@abstractmethod
def get_related_model_keys(self, model_key: str) -> list[str]:
"""Gets all models keys related to a given model key."""
pass
@abstractmethod
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
"""Get related model keys for multiple models."""
pass

View File

@@ -0,0 +1,9 @@
from datetime import datetime
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class ModelRelationship(BaseModelExcludeNull):
model_key_1: str
model_key_2: str
created_at: datetime

View File

@@ -0,0 +1,31 @@
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
from invokeai.backend.model_manager.config import AnyModelConfig
class ModelRelationshipsService(ModelRelationshipsServiceABC):
__invoker: Invoker
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
self.__invoker.services.model_relationship_records.add_model_relationship(model_key_1, model_key_2)
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
self.__invoker.services.model_relationship_records.remove_model_relationship(model_key_1, model_key_2)
def get_related_model_keys(self, model_key: str) -> list[str]:
return self.__invoker.services.model_relationship_records.get_related_model_keys(model_key)
def add_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None:
self.add_model_relationship(model_1.key, model_2.key)
def remove_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None:
self.remove_model_relationship(model_1.key, model_2.key)
def get_related_keys_from_model(self, model: AnyModelConfig) -> list[str]:
return self.get_related_model_keys(model.key)
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
return self.__invoker.services.model_relationship_records.get_related_model_keys_batch(model_keys)
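At the service layer the same operations can be driven from model configs rather than raw keys; a hedged example (the configs are hypothetical):
# Hypothetical configs for a LoRA and its base model.
services.model_relationships.add_relationship_from_models(lora_config, main_config)
services.model_relationships.get_related_keys_from_model(main_config)   # -> [lora_config.key] under this setup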

View File

@@ -1,3 +1,4 @@
import gc
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
@@ -210,7 +211,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
self._services.performance_statistics.reset_stats(queue_item.session.id)
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
@@ -439,6 +440,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
# GC-ing here can reduce peak memory usage of the invoke process by freeing allocated memory blocks.
# Most queue items take seconds to execute, so the relative cost of a GC is very small.
# Python will never cede allocated memory back to the OS, so anything we can do to reduce the peak
# allocation is well worth it.
gc.collect()
self._invoker.services.logger.info(
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
)

View File

@@ -10,6 +10,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
@@ -17,7 +19,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
@@ -92,6 +93,11 @@ class SessionQueueBase(ABC):
"""Cancels a session queue item"""
pass
@abstractmethod
def delete_queue_item(self, item_id: int) -> None:
"""Deletes a session queue item"""
pass
@abstractmethod
def fail_queue_item(
self, item_id: int, error_type: str, error_message: str, error_traceback: str
@@ -109,6 +115,11 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with the given batch destination"""
pass
@abstractmethod
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
"""Deletes all queue items with the given batch destination"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""
@@ -119,6 +130,11 @@ class SessionQueueBase(ABC):
"""Cancels all queue items except in-progress items"""
pass
@abstractmethod
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
"""Deletes all queue items except in-progress items"""
pass
@abstractmethod
def list_queue_items(
self,
@@ -127,10 +143,20 @@ class SessionQueueBase(ABC):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets a page of session queue items"""
pass
@abstractmethod
def list_all_queue_items(
self,
queue_id: str,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
pass
@abstractmethod
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""

View File

@@ -148,7 +148,7 @@ class Batch(BaseModel):
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
except NodeNotFoundError:
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
if batch_data.field_name not in node.model_fields:
if batch_data.field_name not in type(node).model_fields:
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
return values
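The model_fields lookup just above moves from the instance to the class, which is Pydantic v2's preferred access pattern; a minimal standalone illustration (the Node model is a hypothetical stand-in):
from pydantic import BaseModel

class Node(BaseModel):
    prompt: str = ""

node = Node()
assert "prompt" in type(node).model_fields   # class-level access, as used above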
@@ -205,9 +205,10 @@ class FieldIdentifier(BaseModel):
kind: Literal["input", "output"] = Field(description="The kind of field")
node_id: str = Field(description="The ID of the node")
field_name: str = Field(description="The name of the field")
user_label: str | None = Field(description="The user label of the field, if any")
class SessionQueueItemWithoutGraph(BaseModel):
class SessionQueueItem(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
item_id: int = Field(description="The identifier of the session queue item")
@@ -251,41 +252,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
default=None,
description="The ID of the published workflow associated with this queue item",
)
api_input_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The fields that were used as input to the API"
)
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
return SessionQueueItemDTO(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra={
"required": [
"item_id",
"status",
"batch_id",
"queue_id",
"session_id",
"priority",
"session_id",
"created_at",
"updated_at",
]
}
)
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass
class SessionQueueItem(SessionQueueItemWithoutGraph):
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow associated with this queue item"
@@ -365,6 +332,7 @@ class EnqueueBatchResult(BaseModel):
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
item_ids: list[int] = Field(description="The IDs of the queue items that were enqueued")
class RetryItemsResult(BaseModel):
@@ -396,6 +364,18 @@ class CancelByDestinationResult(CancelByBatchIDsResult):
pass
class DeleteByDestinationResult(BaseModel):
"""Result of deleting by a destination"""
deleted: int = Field(..., description="Number of queue items deleted")
class DeleteAllExceptCurrentResult(DeleteByDestinationResult):
"""Result of deleting all except current"""
pass
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""

View File

@@ -17,6 +17,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
@@ -24,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
@@ -46,22 +47,17 @@ class SqliteSessionQueue(SessionQueueBase):
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
else:
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -69,102 +65,104 @@ class SqliteSessionQueue(SessionQueueBase):
WHERE status = 'in_progress';
"""
)
except Exception:
self._conn.rollback()
raise
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(int, cursor.fetchone()[0])
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(Union[int, None], cursor.fetchone()[0]) or 0
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
return priority
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
cursor = self._conn.cursor()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = calc_session_count(batch)
values_to_insert = prepare_values_to_insert(
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
with self._db.transaction() as cursor:
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
cursor.execute(
"""--sql
SELECT item_id
FROM session_queue
WHERE batch_id = ?
ORDER BY item_id DESC;
""",
(batch.batch_id,),
)
item_ids = [row[0] for row in cursor.fetchall()]
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
enqueued=enqueued_count,
batch=batch,
priority=priority,
item_ids=item_ids,
)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
return enqueue_result
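enqueue_batch is now a coroutine that delegates to a worker thread, keeping the blocking sqlite work off the event loop; a condensed sketch of the pattern (names are hypothetical):
import asyncio

async def enqueue(queue, batch):
    # The public coroutine hops to a thread for the blocking DB insert via asyncio.to_thread.
    result = await queue.enqueue_batch("default", batch, prepend=False)
    return result.item_ids   # newly returned on EnqueueBatchResult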
def dequeue(self) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -172,40 +170,40 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -218,8 +216,23 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
""",
(item_id,),
)
row = cursor.fetchone()
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -228,10 +241,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -239,35 +249,34 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
@@ -285,24 +294,19 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id = ?
AND (
queue_id = ?
AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
)
)
"""
cursor.execute(
f"""--sql
@@ -321,16 +325,28 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
return queue_item
def delete_queue_item(self, item_id: int) -> None:
"""Deletes a session queue item"""
try:
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE item_id = ?
""",
(item_id,),
)
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
return queue_item
@@ -352,8 +368,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -363,6 +378,8 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id] + batch_ids
cursor.execute(
@@ -382,17 +399,14 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -401,6 +415,8 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = (queue_id, destination)
cursor.execute(
@@ -420,17 +436,67 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
params = (queue_id, destination)
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
""",
params,
)
count = cursor.fetchone()[0]
cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
""",
params,
)
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
AND status == 'pending'
"""
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
DELETE
FROM session_queue
{where};
""",
(queue_id,),
)
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -438,6 +504,8 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id]
cursor.execute(
@@ -457,21 +525,13 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -494,30 +554,25 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
@@ -530,10 +585,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get_queue_item(item_id)
def list_queue_items(
@@ -543,53 +594,45 @@ class SqliteSessionQueue(SessionQueueBase):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
@@ -597,21 +640,52 @@ class SqliteSessionQueue(SessionQueueBase):
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
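A hedged sketch of paging through the queue with the cursor API and the new destination filter (argument values are placeholders):
page = queue.list_queue_items(queue_id="default", limit=50, priority=0, destination="canvas")
while page.has_more:
    last = page.items[-1]
    page = queue.list_queue_items(
        queue_id="default",
        limit=50,
        priority=last.priority,
        cursor=last.item_id,
        destination="canvas",
    )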
def list_all_queue_items(
self,
queue_id: str,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
with self._db.transaction() as cursor:
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor.execute(query, params)
results = cast(list[sqlite3.Row], cursor.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] for row in counts_result)
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueStatus(
queue_id=queue_id,
@@ -627,20 +701,20 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in result)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
@@ -659,20 +733,20 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in counts_result)
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueCountsByDestination(
@@ -688,8 +762,7 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []
@@ -740,10 +813,6 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,

View File

@@ -2,11 +2,12 @@
import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
import networkx as nx
from pydantic import (
BaseModel,
ConfigDict,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
@@ -57,17 +58,32 @@ class Edge(BaseModel):
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want to risk breaking anything.
try:
invocation_class = type(node)
invocation_output_class = invocation_class.get_output_annotation()
field_info = invocation_output_class.model_fields.get(field)
assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}"
output_field_type = field_info.annotation
return output_field_type
except Exception:
return None
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
return node_input_field
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined on the node, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want to risk breaking anything.
try:
invocation_class = type(node)
field_info = invocation_class.model_fields.get(field)
assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}"
input_field_type = field_info.annotation
return input_field_type
except Exception:
return None
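The new helpers resolve field types from Pydantic's model_fields rather than get_type_hints; the core lookup, shown on a hypothetical stand-in model, works like this:
from pydantic import BaseModel

class Example(BaseModel):
    steps: int = 20

field_info = Example.model_fields.get("steps")
assert field_info is not None and field_info.annotation is int   # same lookup as above
assert Example.model_fields.get("missing") is None               # drives the return-None path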
def is_union_subtype(t1, t2):
@@ -424,7 +440,7 @@ class Graph(BaseModel):
)
# input fields are on the node
if edge.destination.field not in destination_node.model_fields:
if edge.destination.field not in type(destination_node).model_fields:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
@@ -787,6 +803,22 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
model_config = ConfigDict(
json_schema_extra={
"required": [
"id",
"graph",
"execution_graph",
"executed",
"executed_history",
"results",
"errors",
"prepared_source_mapping",
"source_prepared_mapping",
]
}
)
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
@@ -975,10 +1007,11 @@ class GraphExecutionState(BaseModel):
new_node_ids = []
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list(
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
)
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
all_iteration_mappings = []
for source_node_id in next_node_parents:
prepared_nodes = self.source_prepared_mapping[source_node_id]
all_iteration_mappings.extend([(source_node_id, p) for p in prepared_nodes])
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
if create_results is not None:
new_node_ids.extend(create_results)

View File

@@ -21,6 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.step_callback import diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ModelConfigBase,
)
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
@@ -543,6 +544,30 @@ class ModelsInterface(InvocationContextInterface):
self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path:
"""Gets the absolute path for a given model config or path.
For example, if the model's path is `flux/main/FLUX Dev.safetensors`, and the models path is
`/home/username/InvokeAI/models`, this method will return
`/home/username/InvokeAI/models/flux/main/FLUX Dev.safetensors`.
Args:
config_or_path: The model config or path.
Returns:
The absolute path to the model.
"""
model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path)
if model_path.is_absolute():
return model_path.resolve()
base_models_path = self._services.configuration.models_path
joined_path = base_models_path / model_path
resolved_path = joined_path.resolve()
return resolved_path
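Following the docstring's example, a hypothetical call from an invocation (assuming the usual context.models handle and models_path of /home/username/InvokeAI/models):
context.models.get_absolute_path("flux/main/FLUX Dev.safetensors")
# -> Path("/home/username/InvokeAI/models/flux/main/FLUX Dev.safetensors")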
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:

View File

@@ -1,4 +1,7 @@
import sqlite3
import threading
from collections.abc import Generator
from contextlib import contextmanager
from logging import Logger
from pathlib import Path
@@ -26,46 +29,65 @@ class SqliteDatabase:
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self.logger = logger
self.db_path = db_path
self.verbose = verbose
self._logger = logger
self._db_path = db_path
self._verbose = verbose
self._lock = threading.RLock()
if not self.db_path:
if not self._db_path:
logger.info("Initializing in-memory database")
else:
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Initializing database at {self.db_path}")
self._db_path.parent.mkdir(parents=True, exist_ok=True)
self._logger.info(f"Initializing database at {self._db_path}")
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.conn.row_factory = sqlite3.Row
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
self._conn.row_factory = sqlite3.Row
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
if self._verbose:
self._conn.set_trace_callback(self._logger.debug)
# Enable foreign key constraints
self.conn.execute("PRAGMA foreign_keys = ON;")
self._conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
self.conn.execute("PRAGMA journal_mode = WAL;")
self._conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
if not self.db_path:
if not self._db_path:
return
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
with self._conn as conn:
initial_db_size = Path(self._db_path).stat().st_size
conn.execute("VACUUM;")
conn.commit()
final_db_size = Path(self._db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
self._logger.error(f"Error cleaning database: {e}")
raise
@contextmanager
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""
Thread-safe context manager for DB work.
Acquires the RLock, yields a Cursor, then commits or rolls back.
"""
with self._lock:
cursor = self._conn.cursor()
try:
yield cursor
self._conn.commit()
except BaseException:
self._conn.rollback()
raise
finally:
cursor.close()
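Editor's note: a minimal usage sketch of the new context manager. The table is hypothetical, and the `SqliteDatabase` import path is assumed; commit and rollback are handled by `transaction()` itself:

from logging import getLogger

from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase  # import path assumed

db = SqliteDatabase(db_path=None, logger=getLogger(__name__))  # in-memory database

with db.transaction() as cursor:
    cursor.execute("CREATE TABLE IF NOT EXISTS example_table (name TEXT);")
    cursor.execute("INSERT INTO example_table (name) VALUES (?);", ("demo",))
# on normal exit the transaction is committed; an exception inside the block rolls it back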

View File

@@ -22,6 +22,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_16 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import build_migration_17
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -61,6 +62,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_17())
migrator.register_migration(build_migration_18())
migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,37 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration20Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""
-- many-to-many relationship table for models
CREATE TABLE IF NOT EXISTS model_relationships (
-- model_key_1 and model_key_2 are the same as the key (primary key) in the models table
model_key_1 TEXT NOT NULL,
model_key_2 TEXT NOT NULL,
created_at TEXT DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
PRIMARY KEY (model_key_1, model_key_2),
-- model_key_1 < model_key_2, to ensure uniqueness and prevent duplicates
FOREIGN KEY (model_key_1) REFERENCES models(id) ON DELETE CASCADE,
FOREIGN KEY (model_key_2) REFERENCES models(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""
-- Create an index so searches on model_key_2 perform as well as searches on model_key_1
CREATE INDEX IF NOT EXISTS keyx_model_relationships_model_key_2
ON model_relationships(model_key_2)
"""
)
def build_migration_20() -> Migration:
return Migration(
from_version=19,
to_version=20,
callback=Migration20Callback(),
)
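Editor's note: given the `model_key_1 < model_key_2` convention noted in the table comment, a hedged sketch of inserting a relationship row. The keys are placeholders and must exist in the models table for the foreign keys to be satisfied; `db` is an InvokeAI SqliteDatabase as in the `transaction()` sketch above:

key_1, key_2 = sorted(("key-of-model-a", "key-of-model-b"))  # enforce model_key_1 < model_key_2
with db.transaction() as cursor:
    cursor.execute(
        "INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?);",
        (key_1, key_2),
    )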

View File

@@ -32,7 +32,7 @@ class SqliteMigrator:
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
self._logger = db.logger
self._logger = db._logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
@@ -45,7 +45,7 @@ class SqliteMigrator:
"""Migrates the database to the latest version."""
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
cursor = self._db._conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
@@ -59,13 +59,13 @@ class SqliteMigrator:
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
if self._db._db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
self._db._conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
@@ -81,7 +81,7 @@ class SqliteMigrator:
try:
# Using sqlite3.Connection as a context manager commits the transaction on exit, or rolls it back if an
# exception is raised.
with self._db.conn as conn:
with self._db._conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(

View File

@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -25,24 +25,23 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
@@ -60,16 +59,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
@@ -90,16 +84,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
@@ -122,15 +111,10 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from style_presets
@@ -138,51 +122,41 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
main_query = """
SELECT
*
FROM style_presets
"""
with self._db.transaction() as cursor:
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
main_query += "ORDER BY LOWER(name) ASC"
cursor = self._conn.cursor()
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
rows = cursor.fetchall()
rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# First delete all existing default style presets
cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
# Next, parse and create the default style presets
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)

View File

@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
@@ -51,9 +51,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
try:
with self._db.transaction() as cursor:
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
@@ -64,18 +63,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_with_id.id, workflow_with_id.model_dump_json()),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE workflow_library
@@ -84,18 +78,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow.model_dump_json(), workflow.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from workflow_library
@@ -103,10 +92,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(
@@ -121,108 +106,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
with self._db.transaction() as cursor:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
# We will construct the query dynamically based on the query params
# We will construct the query dynamically based on the query params
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
conditions.append(category_condition)
params.extend(category_params)
conditions.append(category_condition)
params.extend(category_params)
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
conditions.append(tags_condition)
params.extend(tags_params)
conditions.append(tags_condition)
params.extend(tags_params)
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
# Put a ring on it
main_query += ";"
count_query += ";"
# Put a ring on it
main_query += ";"
count_query += ";"
cursor = self._conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
@@ -247,46 +232,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
return result
@@ -296,52 +281,51 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
return result
def update_opened_at(self, workflow_id: str) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -350,10 +334,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -368,8 +348,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
@@ -449,8 +428,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(w.model_dump_json(), w.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
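Editor's note: the storage methods above build their SQL dynamically but always keep values out of the query string. The pattern, reduced to a standalone sketch with an illustrative filter value:

tag = "portrait"  # illustrative filter value
conditions: list[str] = []
params: list[str | int] = []

if tag:
    conditions.append("tags LIKE ?")
    params.append(f"%{tag}%")

count_query = "SELECT COUNT(*) FROM workflow_library"
if conditions:
    count_query += " WHERE " + " AND ".join(conditions)
count_query += ";"
# cursor.execute(count_query, params)  -- values are always bound, never interpolated into the SQL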

View File

@@ -230,6 +230,86 @@ def heuristic_resize(np_img: np.ndarray[Any, Any], size: tuple[int, int]) -> np.
return resized
# precompute common kernels
_KERNEL3 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
# directional masks for NMS
_DIRS = [
np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], np.uint8),
np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], np.uint8),
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], np.uint8),
np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], np.uint8),
]
def heuristic_resize_fast(np_img: np.ndarray, size: tuple[int, int]) -> np.ndarray:
h, w = np_img.shape[:2]
# early exit
if (w, h) == size:
return np_img
# separate alpha channel
img = np_img
alpha = None
if img.ndim == 3 and img.shape[2] == 4:
alpha, img = img[:, :, 3], img[:, :, :3]
# build a small sample for unique-color & binary detection
flat = img.reshape(-1, img.shape[-1])
N = flat.shape[0]
# include four corners to avoid missing extreme values
corners = np.vstack([img[0, 0], img[0, w - 1], img[h - 1, 0], img[h - 1, w - 1]])
cnt = min(N, 100_000)
samp = np.vstack([corners, flat[np.random.choice(N, cnt, replace=False)]])
uc = np.unique(samp, axis=0).shape[0]
vmin, vmax = samp.min(), samp.max()
# detect binary edge map & one-pixel-edge case
is_binary = uc == 2 and vmin < 16 and vmax > 240
one_pixel_edge = False
if is_binary:
# single gray conversion
gray0 = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
grad = cv2.morphologyEx(gray0, cv2.MORPH_GRADIENT, _KERNEL3)
cnt_edge = cv2.countNonZero(grad)
cnt_all = cv2.countNonZero((gray0 > 127).astype(np.uint8))
one_pixel_edge = (2 * cnt_edge) > cnt_all
# choose interp for color/seg/grayscale
area_new, area_old = size[0] * size[1], w * h
if 2 < uc < 200: # segmentation map
interp = cv2.INTER_NEAREST
elif area_new < area_old:
interp = cv2.INTER_AREA
else:
interp = cv2.INTER_CUBIC
# single resize pass on RGB
resized = cv2.resize(img, size, interpolation=interp)
if is_binary:
# convert to gray & apply NMS via OpenCV's (C++) dilate
gray_r = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
nms = np.zeros_like(gray_r)
for K in _DIRS:
d = cv2.dilate(gray_r, K)
mask = d == gray_r
nms[mask] = gray_r[mask]
# threshold + thinning if needed
_, bw = cv2.threshold(nms, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
out_bin = cv2.ximgproc.thinning(bw) if one_pixel_edge else bw
# restore 3 channels
resized = np.stack([out_bin] * 3, axis=2)
# restore alpha with same interp as RGB for consistency
if alpha is not None:
am = cv2.resize(alpha, size, interpolation=interp)
am = (am > 127).astype(np.uint8) * 255
resized = np.dstack((resized, am))
return resized
###########################################################################
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
# modified for InvokeAI
@@ -244,7 +324,7 @@ def np_img_resize(
np_img = normalize_image_channel_count(np_img)
if resize_mode == "just_resize": # RESIZE
np_img = heuristic_resize(np_img, (w, h))
np_img = heuristic_resize_fast(np_img, (w, h))
np_img = clone_contiguous(np_img)
return np_img_to_torch(np_img, device), np_img
@@ -265,7 +345,7 @@ def np_img_resize(
# Inpaint hijack
high_quality_border_color[3] = 255
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
np_img = heuristic_resize_fast(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (h - new_h) // 2)
pad_w = max(0, (w - new_w) // 2)
@@ -275,7 +355,7 @@ def np_img_resize(
return np_img_to_torch(np_img, device), np_img
else: # resize_mode == "crop_resize" (INNER_FIT)
k = max(k0, k1)
np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
np_img = heuristic_resize_fast(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (new_h - h) // 2)
pad_w = max(0, (new_w - w) // 2)
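Editor's note: a rough usage sketch of `heuristic_resize_fast` with a synthetic image; in practice a real control image (e.g. a Canny edge map) is passed, the import path is assumed, and the thinning branch additionally needs `cv2.ximgproc` from opencv-contrib-python:

import numpy as np

from invokeai.app.util.controlnet_utils import heuristic_resize_fast  # module path assumed

np_img = np.zeros((480, 640, 3), dtype=np.uint8)     # synthetic RGB input
resized = heuristic_resize_fast(np_img, (512, 512))  # size is (width, height)
print(resized.shape)                                 # (512, 512, 3)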

View File

@@ -12,6 +12,9 @@ from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFie
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
@@ -61,6 +64,10 @@ def get_openapi_func(
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in InvocationRegistry.get_output_classes():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
# Remove output_meta, which is only used on the back end, from the schema
if "output_meta" in json_schema["properties"]:
json_schema["properties"].pop("output_meta")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema

View File

@@ -10,7 +10,7 @@ def get_timestamp() -> int:
def get_iso_timestamp() -> str:
return datetime.datetime.utcnow().isoformat()
return datetime.datetime.now(datetime.timezone.utc).isoformat()
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
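Editor's note: the replacement swaps the deprecated naive `utcnow()` for a timezone-aware call; the only visible difference is the explicit UTC offset in the ISO string. Output below is illustrative:

import datetime

# old (naive):  2025-07-24T14:03:21.123456
# new (aware):  2025-07-24T14:03:21.123456+00:00
print(datetime.datetime.now(datetime.timezone.utc).isoformat())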

View File

@@ -123,7 +123,11 @@ def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
if total_steps == 0:
return 0.0
if order == 2:
return floor(step / 2) / floor(total_steps / 2)
# Prevent division by zero when total_steps is 1
denominator = floor(total_steps / 2)
if denominator == 0:
return 0.0
return floor(step / 2) / denominator
# order == 1
return step / total_steps
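Editor's note, a worked check of the guard: with a second-order scheduler and total_steps = 1, floor(total_steps / 2) = 0, so the old expression divided by zero; the patched code now returns 0.0. With total_steps = 10 and step = 4, the result is floor(4 / 2) / floor(10 / 2) = 2 / 5 = 0.4, unchanged from before.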

View File

View File

@@ -0,0 +1,314 @@
import math
import os
from typing import List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils import logging
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_t5_prompt_embeds(
tokenizer: T5TokenizerFast,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str], None] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
# padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
# Concat zeros to max_sequence
b, seq_len, dim = prompt_embeds.shape
if seq_len < max_sequence_length:
padding = torch.zeros(
(b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
prompt_embeds = prompt_embeds.to(device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# in order to get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
sigmas = timesteps / num_train_timesteps
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
new_sigmas = sigmas[inds]
return new_sigmas
def is_ng_none(negative_prompt):
return (
negative_prompt is None
or negative_prompt == ""
or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
)
class CudaTimerContext:
def __init__(self, times_arr):
self.times_arr = times_arr
def __enter__(self):
self.before_event = torch.cuda.Event(enable_timing=True)
self.after_event = torch.cuda.Event(enable_timing=True)
self.before_event.record()
def __exit__(self, type, value, traceback):
self.after_event.record()
torch.cuda.synchronize()
elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
self.times_arr.append(elapsed_time)
def get_env_prefix():
env = os.environ.get("CLOUD_PROVIDER", "AWS").upper()
if env == "AWS":
return "SM_CHANNEL"
elif env == "AZURE":
return "AZUREML_DATAREFERENCE"
raise Exception(f"Env {env} not supported")
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
def initialize_distributed():
# Initialize the process group for distributed training
dist.init_process_group("nccl")
# Get the current process's rank (ID) and the total number of processes (world size)
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"Initialized distributed training: Rank {rank}/{world_size}")
def get_clip_prompt_embeds(
text_encoder: CLIPTextModel,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 77,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
assert max_sequence_length == tokenizer.model_max_length
prompt = [prompt] if isinstance(prompt, str) else prompt
# Define tokenizers and text encoders
tokenizers = [tokenizer, tokenizer_2]
text_encoders = [text_encoder, text_encoder_2]
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
# We are only ever interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, pooled_prompt_embeds
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concatenated with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
class FluxPosEmbed(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
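Editor's note: a small sanity check of `get_original_sigmas` above, assuming the helper from this new Bria module is in scope. With the default 1000 training timesteps the sigma schedule runs from 1.0 down to 0.001, and a smaller `num_inference_steps` simply subsamples it (values are approximate):

sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=4)
print(sigmas)  # ~[1.0, 0.667, 0.334, 0.001]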

View File

@@ -0,0 +1,6 @@
__version__ = "0.0.9"
from invokeai.backend.bria.controlnet_aux.canny import CannyDetector as CannyDetector
from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector as OpenposeDetector
__all__ = ["CannyDetector", "OpenposeDetector"]

View File

@@ -0,0 +1,48 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
class CannyDetector:
def __call__(
self,
input_image=None,
low_threshold=100,
high_threshold=200,
detect_resolution=512,
image_resolution=512,
output_type=None,
**kwargs,
):
if "img" in kwargs:
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning, stacklevel=2)
input_image = kwargs.pop("img")
if input_image is None:
raise ValueError("input_image must be defined.")
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
output_type = output_type or "pil"
else:
output_type = output_type or "np"
input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)
detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
detected_map = HWC3(detected_map)
img = resize_image(input_image, image_resolution)
H, W, C = img.shape
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map
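Editor's note: a hedged usage sketch of `CannyDetector`; the input path is a placeholder. With a PIL image in, a PIL edge map comes back:

from PIL import Image

from invokeai.backend.bria.controlnet_aux import CannyDetector

canny = CannyDetector()
image = Image.open("input.png")  # placeholder path
edge_map = canny(image, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512)
edge_map.save("canny.png")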

View File

@@ -0,0 +1,108 @@
OPENPOSE: MULTIPERSON KEYPOINT DETECTION
SOFTWARE LICENSE AGREEMENT
ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
COPYRIGHT: The Software is owned by Licensor and is protected by United
States copyright laws and applicable international treaties and/or conventions.
PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark "OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
EXPORT REGULATION: Licensee agrees to comply with any and all applicable
U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
************************************************************************
THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
COPYRIGHT
All contributions by the University of California:
Copyright (c) 2014-2017 The Regents of the University of California (Regents)
All rights reserved.
All other contributions:
Copyright (c) 2014-2017, the respective contributors
All rights reserved.
Caffe uses a shared copyright model: each contributor holds copyright over
their contributions to Caffe. The project versioning records all such
contribution and copyright details. If a contributor wants to further mark
their specific copyright on a particular contribution, they should indicate
their copyright solely in the commit message of the change when it is
committed.
LICENSE
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
CONTRIBUTION AGREEMENT
By contributing to the BVLC/caffe repository through pull-request, comment,
or otherwise, the contributor releases their content to the
license and copyright terms herein.
************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********

View File

@@ -0,0 +1,267 @@
# Openpose
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
# 3rd Edited by ControlNet
# 4th Edited by ControlNet (added face and correct hands)
# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixes)
# This preprocessor is licensed by CMU for non-commercial use only.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import warnings
from typing import List, NamedTuple, Tuple, Union
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.body import Body, BodyResult, Keypoint
from invokeai.backend.bria.controlnet_aux.open_pose.face import Face
from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
HandResult = List[Keypoint]
FaceResult = List[Keypoint]
class PoseResult(NamedTuple):
body: BodyResult
left_hand: Union[HandResult, None]
right_hand: Union[HandResult, None]
face: Union[FaceResult, None]
def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
"""
Draw the detected poses on an empty canvas.
Args:
poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
H (int): The height of the canvas.
W (int): The width of the canvas.
draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.
Returns:
numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
"""
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
for pose in poses:
if draw_body:
canvas = util.draw_bodypose(canvas, pose.body.keypoints)
if draw_hand:
canvas = util.draw_handpose(canvas, pose.left_hand)
canvas = util.draw_handpose(canvas, pose.right_hand)
if draw_face:
canvas = util.draw_facepose(canvas, pose.face)
return canvas
class OpenposeDetector:
"""
A class for detecting human poses in images using the Openpose model.
Attributes:
model_dir (str): Path to the directory where the pose models are stored.
"""
def __init__(self, body_estimation, hand_estimation=None, face_estimation=None):
self.body_estimation = body_estimation
self.hand_estimation = hand_estimation
self.face_estimation = face_estimation
@classmethod
def from_pretrained(
cls,
pretrained_model_or_path,
filename=None,
hand_filename=None,
face_filename=None,
cache_dir=None,
local_files_only=False,
):
if pretrained_model_or_path == "lllyasviel/ControlNet":
filename = filename or "annotator/ckpts/body_pose_model.pth"
hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth"
face_filename = face_filename or "facenet.pth"
face_pretrained_model_or_path = "lllyasviel/Annotators"
else:
filename = filename or "body_pose_model.pth"
hand_filename = hand_filename or "hand_pose_model.pth"
face_filename = face_filename or "facenet.pth"
face_pretrained_model_or_path = pretrained_model_or_path
if os.path.isdir(pretrained_model_or_path):
body_model_path = os.path.join(pretrained_model_or_path, filename)
hand_model_path = os.path.join(pretrained_model_or_path, hand_filename)
face_model_path = os.path.join(face_pretrained_model_or_path, face_filename)
else:
body_model_path = hf_hub_download(
pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only
)
hand_model_path = hf_hub_download(
pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only
)
face_model_path = hf_hub_download(
face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only
)
body_estimation = Body(body_model_path)
hand_estimation = Hand(hand_model_path)
face_estimation = Face(face_model_path)
return cls(body_estimation, hand_estimation, face_estimation)
def to(self, device):
self.body_estimation.to(device)
self.hand_estimation.to(device)
self.face_estimation.to(device)
return self
def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
left_hand = None
right_hand = None
H, W, _ = oriImg.shape
for x, y, w, is_left in util.handDetect(body, oriImg):
peaks = self.hand_estimation(oriImg[y : y + w, x : x + w, :]).astype(np.float32)
if peaks.ndim == 2 and peaks.shape[1] == 2:
peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
hand_result = [Keypoint(x=peak[0], y=peak[1]) for peak in peaks]
if is_left:
left_hand = hand_result
else:
right_hand = hand_result
return left_hand, right_hand
def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
face = util.faceDetect(body, oriImg)
if face is None:
return None
x, y, w = face
H, W, _ = oriImg.shape
heatmaps = self.face_estimation(oriImg[y : y + w, x : x + w, :])
peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
if peaks.ndim == 2 and peaks.shape[1] == 2:
peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
return [Keypoint(x=peak[0], y=peak[1]) for peak in peaks]
return None
def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
"""
Detect poses in the given image.
Args:
oriImg (numpy.ndarray): The input image for pose detection.
include_hand (bool, optional): Whether to include hand detection. Defaults to False.
include_face (bool, optional): Whether to include face detection. Defaults to False.
Returns:
List[PoseResult]: A list of PoseResult objects containing the detected poses.
"""
oriImg = oriImg[:, :, ::-1].copy()
H, W, C = oriImg.shape
with torch.no_grad():
candidate, subset = self.body_estimation(oriImg)
bodies = self.body_estimation.format_body_result(candidate, subset)
results = []
for body in bodies:
left_hand, right_hand, face = (None,) * 3
if include_hand:
left_hand, right_hand = self.detect_hands(body, oriImg)
if include_face:
face = self.detect_face(body, oriImg)
results.append(
PoseResult(
BodyResult(
keypoints=[
Keypoint(x=keypoint.x / float(W), y=keypoint.y / float(H))
if keypoint is not None
else None
for keypoint in body.keypoints
],
total_score=body.total_score,
total_parts=body.total_parts,
),
left_hand,
right_hand,
face,
)
)
return results
def __call__(
self,
input_image,
detect_resolution=512,
image_resolution=512,
include_body=True,
include_hand=False,
include_face=False,
hand_and_face=None,
output_type="pil",
**kwargs,
):
if hand_and_face is not None:
warnings.warn(
"hand_and_face is deprecated. Use include_hand and include_face instead.",
DeprecationWarning,
stacklevel=2,
)
include_hand = hand_and_face
include_face = hand_and_face
if "return_pil" in kwargs:
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning, stacklevel=2)
output_type = "pil" if kwargs["return_pil"] else "np"
if isinstance(output_type, bool):
warnings.warn(
"Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions",
DeprecationWarning,
stacklevel=2,
)
if output_type:
output_type = "pil"
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)
H, W, C = input_image.shape
poses = self.detect_poses(input_image, include_hand, include_face)
canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
detected_map = canvas
detected_map = HWC3(detected_map)
img = resize_image(input_image, image_resolution)
H, W, C = img.shape
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map
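# Minimal end-to-end sketch, mirroring the __main__ demo in hand.py. It assumes the
# "lllyasviel/Annotators" Hugging Face repo is reachable and that "person.jpg" exists;
# both names are illustrative, not part of the upstream file.
if __name__ == "__main__":
    from PIL import Image

    detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cpu")
    pose_map = detector(Image.open("person.jpg"), include_hand=True, include_face=True)
    pose_map.save("person_pose.png")  # default output_type="pil" returns a PIL image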

View File

@@ -0,0 +1,319 @@
import math
from typing import List, NamedTuple, Union
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.model import bodypose_model
class Keypoint(NamedTuple):
x: float
y: float
score: float = 1.0
id: int = -1
class BodyResult(NamedTuple):
# Note: Using `Union` instead of the `|` operator, as the latter is a Python
# 3.10 feature.
# Annotator code should be Python 3.8 Compatible, as controlnet repo uses
# Python 3.8 environment.
# https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
keypoints: List[Union[Keypoint, None]]
total_score: float
total_parts: int
class Body(object):
def __init__(self, model_path):
self.model = bodypose_model()
model_dict = util.transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def to(self, device):
self.model.to(device)
return self
def __call__(self, oriImg):
device = next(iter(self.model.parameters())).device
# scale_search = [0.5, 1.0, 1.5, 2.0]
scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre1 = 0.1
thre2 = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
data = data.to(device)
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
with torch.no_grad():
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
# extract outputs, resize, and remove padding
# heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
heatmap = heatmap[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :]
heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
paf = util.smart_resize_k(paf, fx=stride, fy=stride)
paf = paf[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :]
paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))
heatmap_avg += heatmap / len(multiplier)
paf_avg += paf / len(multiplier)
all_peaks = []
peak_counter = 0
for part in range(18):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
map_left = np.zeros(one_heatmap.shape)
map_left[1:, :] = one_heatmap[:-1, :]
map_right = np.zeros(one_heatmap.shape)
map_right[:-1, :] = one_heatmap[1:, :]
map_up = np.zeros(one_heatmap.shape)
map_up[:, 1:] = one_heatmap[:, :-1]
map_down = np.zeros(one_heatmap.shape)
map_down[:, :-1] = one_heatmap[:, 1:]
peaks_binary = np.logical_and.reduce(
(
one_heatmap >= map_left,
one_heatmap >= map_right,
one_heatmap >= map_up,
one_heatmap >= map_down,
one_heatmap > thre1,
)
)
peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0], strict=False)) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
all_peaks.append(peaks_with_score_and_id)
peak_counter += len(peaks)
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
[3, 17],
[6, 18],
]
# the middle joints heatmap correspondence
mapIdx = [
[31, 32],
[39, 40],
[33, 34],
[35, 36],
[41, 42],
[43, 44],
[19, 20],
[21, 22],
[23, 24],
[25, 26],
[27, 28],
[29, 30],
[47, 48],
[49, 50],
[53, 54],
[51, 52],
[55, 56],
[37, 38],
[45, 46],
]
connection_all = []
special_k = []
mid_num = 10
for k in range(len(mapIdx)):
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
candA = all_peaks[limbSeq[k][0] - 1]
candB = all_peaks[limbSeq[k][1] - 1]
nA = len(candA)
nB = len(candB)
indexA, indexB = limbSeq[k]
if nA != 0 and nB != 0:
connection_candidate = []
for i in range(nA):
for j in range(nB):
vec = np.subtract(candB[j][:2], candA[i][:2])
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
norm = max(0.001, norm)
vec = np.divide(vec, norm)
startend = list(
zip(
np.linspace(candA[i][0], candB[j][0], num=mid_num),
np.linspace(candA[i][1], candB[j][1], num=mid_num),
strict=False,
)
)
vec_x = np.array(
[
score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 0]
for i in range(len(startend))
]
)
vec_y = np.array(
[
score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 1]
for i in range(len(startend))
]
)
score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
0.5 * oriImg.shape[0] / norm - 1, 0
)
criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
criterion2 = score_with_dist_prior > 0
if criterion1 and criterion2:
connection_candidate.append(
[i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]
)
connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
connection = np.zeros((0, 5))
for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3]
if i not in connection[:, 3] and j not in connection[:, 4]:
connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
if len(connection) >= min(nA, nB):
break
connection_all.append(connection)
else:
special_k.append(k)
connection_all.append([])
# last number in each row is the total parts number of that person
# the second last number in each row is the score of the overall configuration
subset = -1 * np.ones((0, 20))
candidate = np.array([item for sublist in all_peaks for item in sublist])
for k in range(len(mapIdx)):
if k not in special_k:
partAs = connection_all[k][:, 0]
partBs = connection_all[k][:, 1]
indexA, indexB = np.array(limbSeq[k]) - 1
for i in range(len(connection_all[k])): # = 1:size(temp,1)
found = 0
subset_idx = [-1, -1]
for j in range(len(subset)): # 1:size(subset,1):
if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
subset_idx[found] = j
found += 1
if found == 1:
j = subset_idx[0]
if subset[j][indexB] != partBs[i]:
subset[j][indexB] = partBs[i]
subset[j][-1] += 1
subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
elif found == 2: # if found 2 and disjoint, merge them
j1, j2 = subset_idx
membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
if len(np.nonzero(membership == 2)[0]) == 0: # merge
subset[j1][:-2] += subset[j2][:-2] + 1
subset[j1][-2:] += subset[j2][-2:]
subset[j1][-2] += connection_all[k][i][2]
subset = np.delete(subset, j2, 0)
else: # as like found == 1
subset[j1][indexB] = partBs[i]
subset[j1][-1] += 1
subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
# if find no partA in the subset, create a new subset
elif not found and k < 17:
row = -1 * np.ones(20)
row[indexA] = partAs[i]
row[indexB] = partBs[i]
row[-1] = 2
row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
subset = np.vstack([subset, row])
# delete rows of subset that have too few detected parts
deleteIdx = []
for i in range(len(subset)):
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
deleteIdx.append(i)
subset = np.delete(subset, deleteIdx, axis=0)
# subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
# candidate: x, y, score, id
return candidate, subset
@staticmethod
def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
"""
Format the body results from the candidate and subset arrays into a list of BodyResult objects.
Args:
candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
for each body part.
subset (np.ndarray): An array of subsets containing indices to the candidate array for each
person detected. The last two columns of each row hold the total score and total parts
of the person.
Returns:
List[BodyResult]: A list of BodyResult objects, where each object represents a person with
detected keypoints, total score, and total parts.
"""
return [
BodyResult(
keypoints=[
Keypoint(
x=candidate[candidate_index][0],
y=candidate[candidate_index][1],
score=candidate[candidate_index][2],
id=candidate[candidate_index][3],
)
if candidate_index != -1
else None
for candidate_index in person[:18].astype(int)
],
total_score=person[18],
total_parts=person[19],
)
for person in subset
]
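# Sketch of using the body estimator directly; the checkpoint and image paths are hypothetical.
if __name__ == "__main__":
    import cv2

    body_estimation = Body("body_pose_model.pth").to("cpu")
    ori_img = cv2.imread("person.jpg")  # B, G, R order
    candidate, subset = body_estimation(ori_img)
    for person in Body.format_body_result(candidate, subset):
        print(person.total_parts, person.total_score)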

View File

@@ -0,0 +1,307 @@
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init
from torchvision.transforms import ToPILImage, ToTensor
from invokeai.backend.bria.controlnet_aux.open_pose import util
class FaceNet(Module):
"""Model the cascading heatmaps."""
def __init__(self):
super(FaceNet, self).__init__()
# cnn to make feature map
self.relu = ReLU()
self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
self.conv1_1 = Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
self.conv1_2 = Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
self.conv2_1 = Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
self.conv2_2 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
self.conv3_1 = Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
self.conv3_2 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
self.conv3_3 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
self.conv3_4 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
self.conv4_1 = Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv4_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv4_3 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv4_4 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv5_1 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv5_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
self.conv5_3_CPM = Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=1)
# stage1
self.conv6_1_CPM = Conv2d(in_channels=128, out_channels=512, kernel_size=1, stride=1, padding=0)
self.conv6_2_CPM = Conv2d(in_channels=512, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage2
self.Mconv1_stage2 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage2 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage3
self.Mconv1_stage3 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage3 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage4
self.Mconv1_stage4 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage4 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage5
self.Mconv1_stage5 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage5 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
# stage6
self.Mconv1_stage6 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv2_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv3_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv4_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv5_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
self.Mconv6_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
self.Mconv7_stage6 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)
for m in self.modules():
if isinstance(m, Conv2d):
init.constant_(m.bias, 0)
def forward(self, x):
"""Return a list of heatmaps."""
heatmaps = []
h = self.relu(self.conv1_1(x))
h = self.relu(self.conv1_2(h))
h = self.max_pooling_2d(h)
h = self.relu(self.conv2_1(h))
h = self.relu(self.conv2_2(h))
h = self.max_pooling_2d(h)
h = self.relu(self.conv3_1(h))
h = self.relu(self.conv3_2(h))
h = self.relu(self.conv3_3(h))
h = self.relu(self.conv3_4(h))
h = self.max_pooling_2d(h)
h = self.relu(self.conv4_1(h))
h = self.relu(self.conv4_2(h))
h = self.relu(self.conv4_3(h))
h = self.relu(self.conv4_4(h))
h = self.relu(self.conv5_1(h))
h = self.relu(self.conv5_2(h))
h = self.relu(self.conv5_3_CPM(h))
feature_map = h
# stage1
h = self.relu(self.conv6_1_CPM(h))
h = self.conv6_2_CPM(h)
heatmaps.append(h)
# stage2
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage2(h))
h = self.relu(self.Mconv2_stage2(h))
h = self.relu(self.Mconv3_stage2(h))
h = self.relu(self.Mconv4_stage2(h))
h = self.relu(self.Mconv5_stage2(h))
h = self.relu(self.Mconv6_stage2(h))
h = self.Mconv7_stage2(h)
heatmaps.append(h)
# stage3
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage3(h))
h = self.relu(self.Mconv2_stage3(h))
h = self.relu(self.Mconv3_stage3(h))
h = self.relu(self.Mconv4_stage3(h))
h = self.relu(self.Mconv5_stage3(h))
h = self.relu(self.Mconv6_stage3(h))
h = self.Mconv7_stage3(h)
heatmaps.append(h)
# stage4
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage4(h))
h = self.relu(self.Mconv2_stage4(h))
h = self.relu(self.Mconv3_stage4(h))
h = self.relu(self.Mconv4_stage4(h))
h = self.relu(self.Mconv5_stage4(h))
h = self.relu(self.Mconv6_stage4(h))
h = self.Mconv7_stage4(h)
heatmaps.append(h)
# stage5
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage5(h))
h = self.relu(self.Mconv2_stage5(h))
h = self.relu(self.Mconv3_stage5(h))
h = self.relu(self.Mconv4_stage5(h))
h = self.relu(self.Mconv5_stage5(h))
h = self.relu(self.Mconv6_stage5(h))
h = self.Mconv7_stage5(h)
heatmaps.append(h)
# stage6
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage6(h))
h = self.relu(self.Mconv2_stage6(h))
h = self.relu(self.Mconv3_stage6(h))
h = self.relu(self.Mconv4_stage6(h))
h = self.relu(self.Mconv5_stage6(h))
h = self.relu(self.Mconv6_stage6(h))
h = self.Mconv7_stage6(h)
heatmaps.append(h)
return heatmaps
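# Shape sketch: for a (1, 3, 384, 384) input, forward() returns six heatmap tensors of shape
# (1, 71, 48, 48) -- three 2x2 max-pools reduce 384 to 48, and every stage emits 71 channels.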
LOG = logging.getLogger(__name__)
TOTEN = ToTensor()
TOPIL = ToPILImage()
params = {
"gaussian_sigma": 2.5,
"inference_img_size": 736, # 368, 736, 1312
"heatmap_peak_thresh": 0.1,
"crop_scale": 1.5,
"line_indices": [
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[4, 5],
[5, 6],
[6, 7],
[7, 8],
[8, 9],
[9, 10],
[10, 11],
[11, 12],
[12, 13],
[13, 14],
[14, 15],
[15, 16],
[17, 18],
[18, 19],
[19, 20],
[20, 21],
[22, 23],
[23, 24],
[24, 25],
[25, 26],
[27, 28],
[28, 29],
[29, 30],
[31, 32],
[32, 33],
[33, 34],
[34, 35],
[36, 37],
[37, 38],
[38, 39],
[39, 40],
[40, 41],
[41, 36],
[42, 43],
[43, 44],
[44, 45],
[45, 46],
[46, 47],
[47, 42],
[48, 49],
[49, 50],
[50, 51],
[51, 52],
[52, 53],
[53, 54],
[54, 55],
[55, 56],
[56, 57],
[57, 58],
[58, 59],
[59, 48],
[60, 61],
[61, 62],
[62, 63],
[63, 64],
[64, 65],
[65, 66],
[66, 67],
[67, 60],
],
}
class Face(object):
"""
The OpenPose face landmark detector model.
Args:
inference_size: size of the inference image; suggested values:
368, 736, 1312 (default: 736)
gaussian_sigma: blur the heatmaps, default 2.5
heatmap_peak_thresh: return landmark if over threshold, default 0.1
"""
def __init__(self, face_model_path, inference_size=None, gaussian_sigma=None, heatmap_peak_thresh=None):
self.inference_size = inference_size or params["inference_img_size"]
self.sigma = gaussian_sigma or params["gaussian_sigma"]
self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
self.model = FaceNet()
self.model.load_state_dict(torch.load(face_model_path))
self.model.eval()
def to(self, device):
self.model.to(device)
return self
def __call__(self, face_img):
device = next(iter(self.model.parameters())).device
H, W, C = face_img.shape
w_size = 384
x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5
x_data = x_data.to(device)
with torch.no_grad():
hs = self.model(x_data[None, ...])
heatmaps = F.interpolate(hs[-1], (H, W), mode="bilinear", align_corners=True).cpu().numpy()[0]
return heatmaps
def compute_peaks_from_heatmaps(self, heatmaps):
all_peaks = []
for part in range(heatmaps.shape[0]):
map_ori = heatmaps[part].copy()
binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)
if np.sum(binary) == 0:
continue
positions = np.where(binary > 0.5)
intensities = map_ori[positions]
mi = np.argmax(intensities)
y, x = positions[0][mi], positions[1][mi]
all_peaks.append([x, y])
return np.array(all_peaks)
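# Sketch of running the face landmark head on a pre-cropped face patch; paths are hypothetical.
if __name__ == "__main__":
    import cv2

    face_estimation = Face("facenet.pth").to("cpu")
    heatmaps = face_estimation(cv2.imread("face_crop.jpg"))
    peaks = face_estimation.compute_peaks_from_heatmaps(heatmaps)
    print(peaks)  # one (x, y) pixel coordinate per detected landmark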

View File

@@ -0,0 +1,91 @@
import cv2
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
from skimage.measure import label
from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.model import handpose_model
class Hand(object):
def __init__(self, model_path):
self.model = handpose_model()
model_dict = util.transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def to(self, device):
self.model.to(device)
return self
def __call__(self, oriImgRaw):
device = next(iter(self.model.parameters())).device
scale_search = [0.5, 1.0, 1.5, 2.0]
# scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre = 0.05
multiplier = [x * boxsize for x in scale_search]
wsize = 128
heatmap_avg = np.zeros((wsize, wsize, 22))
Hr, Wr, Cr = oriImgRaw.shape
oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = util.smart_resize(oriImg, (scale, scale))
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
data = data.to(device)
with torch.no_grad():
output = self.model(data).cpu().numpy()
# extract outputs, resize, and remove padding
heatmap = np.transpose(np.squeeze(output), (1, 2, 0))  # single output: 22 hand heatmap channels
heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
heatmap = heatmap[: imageToTest_padded.shape[0] - pad[2], : imageToTest_padded.shape[1] - pad[3], :]
heatmap = util.smart_resize(heatmap, (wsize, wsize))
heatmap_avg += heatmap / len(multiplier)
all_peaks = []
for part in range(21):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
if np.sum(binary) == 0:
all_peaks.append([0, 0])
continue
label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
label_img[label_img != max_index] = 0
map_ori[label_img == 0] = 0
y, x = util.npmax(map_ori)
y = int(float(y) * float(Hr) / float(wsize))
x = int(float(x) * float(Wr) / float(wsize))
all_peaks.append([x, y])
return np.array(all_peaks)
if __name__ == "__main__":
hand_estimation = Hand("../model/hand_pose_model.pth")
# test_image = '../images/hand.jpg'
test_image = "../images/hand.jpg"
oriImg = cv2.imread(test_image) # B,G,R order
peaks = hand_estimation(oriImg)
canvas = util.draw_handpose(oriImg, peaks, True)
cv2.imshow("", canvas)
cv2.waitKey(0)

View File

@@ -0,0 +1,240 @@
from collections import OrderedDict
import torch
import torch.nn as nn
def make_layers(block, no_relu_layers):
layers = []
for layer_name, v in block.items():
if "pool" in layer_name:
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
layers.append((layer_name, layer))
else:
conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4])
layers.append((layer_name, conv2d))
if layer_name not in no_relu_layers:
layers.append(("relu_" + layer_name, nn.ReLU(inplace=True)))
return nn.Sequential(OrderedDict(layers))
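# Example: ("conv1_1", [3, 64, 3, 1, 1]) becomes Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
# followed by a ReLU, while ("pool1_stage1", [2, 2, 0]) becomes MaxPool2d(kernel_size=2, stride=2,
# padding=0); names listed in no_relu_layers get no trailing ReLU.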
class bodypose_model(nn.Module):
def __init__(self):
super(bodypose_model, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv5_5_CPM_L1",
"conv5_5_CPM_L2",
"Mconv7_stage2_L1",
"Mconv7_stage2_L2",
"Mconv7_stage3_L1",
"Mconv7_stage3_L2",
"Mconv7_stage4_L1",
"Mconv7_stage4_L2",
"Mconv7_stage5_L1",
"Mconv7_stage5_L2",
"Mconv7_stage6_L1",
"Mconv7_stage6_L1",
]
blocks = {}
block0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3_CPM", [512, 256, 3, 1, 1]),
("conv4_4_CPM", [256, 128, 3, 1, 1]),
]
)
# Stage 1
block1_1 = OrderedDict(
[
("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
]
)
block1_2 = OrderedDict(
[
("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
]
)
blocks["block1_1"] = block1_1
blocks["block1_2"] = block1_2
self.model0 = make_layers(block0, no_relu_layers)
# Stages 2 - 6
for i in range(2, 7):
blocks["block%d_1" % i] = OrderedDict(
[
("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]),
]
)
blocks["block%d_2" % i] = OrderedDict(
[
("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]),
("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_1 = blocks["block1_1"]
self.model2_1 = blocks["block2_1"]
self.model3_1 = blocks["block3_1"]
self.model4_1 = blocks["block4_1"]
self.model5_1 = blocks["block5_1"]
self.model6_1 = blocks["block6_1"]
self.model1_2 = blocks["block1_2"]
self.model2_2 = blocks["block2_2"]
self.model3_2 = blocks["block3_2"]
self.model4_2 = blocks["block4_2"]
self.model5_2 = blocks["block5_2"]
self.model6_2 = blocks["block6_2"]
def forward(self, x):
out1 = self.model0(x)
out1_1 = self.model1_1(out1)
out1_2 = self.model1_2(out1)
out2 = torch.cat([out1_1, out1_2, out1], 1)
out2_1 = self.model2_1(out2)
out2_2 = self.model2_2(out2)
out3 = torch.cat([out2_1, out2_2, out1], 1)
out3_1 = self.model3_1(out3)
out3_2 = self.model3_2(out3)
out4 = torch.cat([out3_1, out3_2, out1], 1)
out4_1 = self.model4_1(out4)
out4_2 = self.model4_2(out4)
out5 = torch.cat([out4_1, out4_2, out1], 1)
out5_1 = self.model5_1(out5)
out5_2 = self.model5_2(out5)
out6 = torch.cat([out5_1, out5_2, out1], 1)
out6_1 = self.model6_1(out6)
out6_2 = self.model6_2(out6)
return out6_1, out6_2
class handpose_model(nn.Module):
def __init__(self):
super(handpose_model, self).__init__()
# these layers have no relu layer
no_relu_layers = [
"conv6_2_CPM",
"Mconv7_stage2",
"Mconv7_stage3",
"Mconv7_stage4",
"Mconv7_stage5",
"Mconv7_stage6",
]
# stage 1
block1_0 = OrderedDict(
[
("conv1_1", [3, 64, 3, 1, 1]),
("conv1_2", [64, 64, 3, 1, 1]),
("pool1_stage1", [2, 2, 0]),
("conv2_1", [64, 128, 3, 1, 1]),
("conv2_2", [128, 128, 3, 1, 1]),
("pool2_stage1", [2, 2, 0]),
("conv3_1", [128, 256, 3, 1, 1]),
("conv3_2", [256, 256, 3, 1, 1]),
("conv3_3", [256, 256, 3, 1, 1]),
("conv3_4", [256, 256, 3, 1, 1]),
("pool3_stage1", [2, 2, 0]),
("conv4_1", [256, 512, 3, 1, 1]),
("conv4_2", [512, 512, 3, 1, 1]),
("conv4_3", [512, 512, 3, 1, 1]),
("conv4_4", [512, 512, 3, 1, 1]),
("conv5_1", [512, 512, 3, 1, 1]),
("conv5_2", [512, 512, 3, 1, 1]),
("conv5_3_CPM", [512, 128, 3, 1, 1]),
]
)
block1_1 = OrderedDict([("conv6_1_CPM", [128, 512, 1, 1, 0]), ("conv6_2_CPM", [512, 22, 1, 1, 0])])
blocks = {}
blocks["block1_0"] = block1_0
blocks["block1_1"] = block1_1
# stage 2-6
for i in range(2, 7):
blocks["block%d" % i] = OrderedDict(
[
("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]),
("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]),
("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]),
("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]),
]
)
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_0 = blocks["block1_0"]
self.model1_1 = blocks["block1_1"]
self.model2 = blocks["block2"]
self.model3 = blocks["block3"]
self.model4 = blocks["block4"]
self.model5 = blocks["block5"]
self.model6 = blocks["block6"]
def forward(self, x):
out1_0 = self.model1_0(x)
out1_1 = self.model1_1(out1_0)
concat_stage2 = torch.cat([out1_1, out1_0], 1)
out_stage2 = self.model2(concat_stage2)
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
out_stage3 = self.model3(concat_stage3)
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
out_stage4 = self.model4(concat_stage4)
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
out_stage5 = self.model5(concat_stage5)
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
out_stage6 = self.model6(concat_stage6)
return out_stage6
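# Shape sketch for both networks with a dummy 368x368 input (random weights, CPU only):
if __name__ == "__main__":
    x = torch.zeros(1, 3, 368, 368)
    with torch.no_grad():
        pafs, heatmaps = bodypose_model().eval()(x)  # 38 PAF channels, 19 body heatmap channels
        hand_maps = handpose_model().eval()(x)       # 22 hand heatmap channels
    print(pafs.shape, heatmaps.shape, hand_maps.shape)
    # torch.Size([1, 38, 46, 46]) torch.Size([1, 19, 46, 46]) torch.Size([1, 22, 46, 46])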

View File

@@ -0,0 +1,436 @@
import math
from typing import List, Tuple, Union
import cv2
import numpy as np
from invokeai.backend.bria.controlnet_aux.open_pose.body import BodyResult, Keypoint
eps = 0.01
def smart_resize(x, s):
Ht, Wt = s
if x.ndim == 2:
Ho, Wo = x.shape
Co = 1
else:
Ho, Wo, Co = x.shape
if Co == 3 or Co == 1:
k = float(Ht + Wt) / float(Ho + Wo)
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
else:
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
def smart_resize_k(x, fx, fy):
if x.ndim == 2:
Ho, Wo = x.shape
Co = 1
else:
Ho, Wo, Co = x.shape
Ht, Wt = Ho * fy, Wo * fx
if Co == 3 or Co == 1:
k = float(Ht + Wt) / float(Ho + Wo)
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
else:
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
def padRightDownCorner(img, stride, padValue):
h = img.shape[0]
w = img.shape[1]
pad = 4 * [None]
pad[0] = 0 # up
pad[1] = 0 # left
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
img_padded = img
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
img_padded = np.concatenate((pad_up, img_padded), axis=0)
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
img_padded = np.concatenate((pad_left, img_padded), axis=1)
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
img_padded = np.concatenate((img_padded, pad_down), axis=0)
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
img_padded = np.concatenate((img_padded, pad_right), axis=1)
return img_padded, pad
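# Example: for a 100x160 image with stride 8, pad == [0, 0, 4, 0] and the padded image is
# 104x160, so both sides become multiples of the network stride.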
def transfer(model, model_weights):
transfered_model_weights = {}
for weights_name in model.state_dict().keys():
transfered_model_weights[weights_name] = model_weights[".".join(weights_name.split(".")[1:])]
return transfered_model_weights
def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
"""
Draw keypoints and limbs representing body pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
H, W, C = canvas.shape
stickwidth = 4
limbSeq = [
[2, 3],
[2, 6],
[3, 4],
[4, 5],
[6, 7],
[7, 8],
[2, 9],
[9, 10],
[10, 11],
[2, 12],
[12, 13],
[13, 14],
[2, 1],
[1, 15],
[15, 17],
[1, 16],
[16, 18],
]
colors = [
[255, 0, 0],
[255, 85, 0],
[255, 170, 0],
[255, 255, 0],
[170, 255, 0],
[85, 255, 0],
[0, 255, 0],
[0, 255, 85],
[0, 255, 170],
[0, 255, 255],
[0, 170, 255],
[0, 85, 255],
[0, 0, 255],
[85, 0, 255],
[170, 0, 255],
[255, 0, 255],
[255, 0, 170],
[255, 0, 85],
]
for (k1_index, k2_index), color in zip(limbSeq, colors, strict=False):
keypoint1 = keypoints[k1_index - 1]
keypoint2 = keypoints[k2_index - 1]
if keypoint1 is None or keypoint2 is None:
continue
Y = np.array([keypoint1.x, keypoint2.x]) * float(W)
X = np.array([keypoint1.y, keypoint2.y]) * float(H)
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
for keypoint, color in zip(keypoints, colors, strict=False):
if keypoint is None:
continue
x, y = keypoint.x, keypoint.y
x = int(x * W)
y = int(y * H)
cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
return canvas
def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
"""
Draw keypoints and connections representing hand pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
keypoints (List[Keypoint] | None): A list of Keypoint objects representing the hand keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
import matplotlib
if not keypoints:
return canvas
H, W, C = canvas.shape
edges = [
[0, 1],
[1, 2],
[2, 3],
[3, 4],
[0, 5],
[5, 6],
[6, 7],
[7, 8],
[0, 9],
[9, 10],
[10, 11],
[11, 12],
[0, 13],
[13, 14],
[14, 15],
[15, 16],
[0, 17],
[17, 18],
[18, 19],
[19, 20],
]
for ie, (e1, e2) in enumerate(edges):
k1 = keypoints[e1]
k2 = keypoints[e2]
if k1 is None or k2 is None:
continue
x1 = int(k1.x * W)
y1 = int(k1.y * H)
x2 = int(k2.x * W)
y2 = int(k2.y * H)
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
cv2.line(
canvas,
(x1, y1),
(x2, y2),
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
thickness=2,
)
for keypoint in keypoints:
x, y = keypoint.x, keypoint.y
x = int(x * W)
y = int(y * H)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
return canvas
def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
"""
Draw keypoints representing face pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
if not keypoints:
return canvas
H, W, C = canvas.shape
for keypoint in keypoints:
x, y = keypoint.x, keypoint.y
x = int(x * W)
y = int(y * H)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
return canvas
# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]:
"""
Detect hands in the input body pose keypoints and calculate the bounding box for each hand.
Args:
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
Returns:
List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left
corner of the bounding box, the width (height) of the bounding box, and
a boolean flag indicating whether the hand is a left hand (True) or a
right hand (False).
Notes:
- The width and height of the bounding boxes are equal since the network requires squared input.
- The minimum bounding box size is 20 pixels.
"""
ratioWristElbow = 0.33
detect_result = []
image_height, image_width = oriImg.shape[0:2]
keypoints = body.keypoints
# right hand: wrist 4, elbow 3, shoulder 2
# left hand: wrist 7, elbow 6, shoulder 5
left_shoulder = keypoints[5]
left_elbow = keypoints[6]
left_wrist = keypoints[7]
right_shoulder = keypoints[2]
right_elbow = keypoints[3]
right_wrist = keypoints[4]
# if any of three not detected
has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist))
has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist))
if not (has_left or has_right):
return []
hands = []
# left hand
if has_left:
hands.append([left_shoulder.x, left_shoulder.y, left_elbow.x, left_elbow.y, left_wrist.x, left_wrist.y, True])
# right hand
if has_right:
hands.append(
[right_shoulder.x, right_shoulder.y, right_elbow.x, right_elbow.y, right_wrist.x, right_wrist.y, False]
)
for x1, y1, x2, y2, x3, y3, is_left in hands:
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
x = x3 + ratioWristElbow * (x3 - x2)
y = y3 + ratioWristElbow * (y3 - y2)
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
# x-y refers to the center --> offset to topLeft point
# handRectangle.x -= handRectangle.width / 2.f;
# handRectangle.y -= handRectangle.height / 2.f;
x -= width / 2
y -= width / 2 # width = height
# overflow the image
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width
width2 = width
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
# keep the hand box only if its width is at least 20 pixels
if width >= 20:
detect_result.append((int(x), int(y), int(width), is_left))
"""
return value: [[x, y, w, True if left hand else False]].
width=height since the network requires squared input.
x, y is the coordinate of top left.
"""
return detect_result
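# Example: each returned tuple is (x, y, w, is_left); OpenposeDetector.detect_hands feeds the
# square crop oriImg[y:y + w, x:x + w, :] to the hand estimator and maps the resulting peaks
# back to normalized image coordinates.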
# Written by Lvmin
def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
"""
Detect the face in the input body pose keypoints and calculate the bounding box for the face.
Args:
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
Returns:
Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
bounding box and the width (height) of the bounding box, or None if the
face is not detected or the bounding box width is less than 20 pixels.
Notes:
- The width and height of the bounding box are equal.
- The minimum bounding box size is 20 pixels.
"""
# left right eye ear 14 15 16 17
image_height, image_width = oriImg.shape[0:2]
keypoints = body.keypoints
head = keypoints[0]
left_eye = keypoints[14]
right_eye = keypoints[15]
left_ear = keypoints[16]
right_ear = keypoints[17]
if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)):
return None
width = 0.0
x0, y0 = head.x, head.y
if left_eye is not None:
x1, y1 = left_eye.x, left_eye.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 3.0)
if right_eye is not None:
x1, y1 = right_eye.x, right_eye.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 3.0)
if left_ear is not None:
x1, y1 = left_ear.x, left_ear.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 1.5)
if right_ear is not None:
x1, y1 = right_ear.x, right_ear.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 1.5)
x, y = x0, y0
x -= width
y -= width
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width * 2
width2 = width * 2
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
if width >= 20:
return int(x), int(y), int(width)
else:
return None
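# Example: a returned (x, y, w) triple is used the same way -- OpenposeDetector.detect_face
# crops oriImg[y:y + w, x:x + w, :], runs the Face estimator, and normalizes the peak
# coordinates by the original image size.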
# get max index of 2d array
def npmax(array):
arrayindex = array.argmax(1)
arrayvalue = array.max(1)
i = arrayvalue.argmax()
j = arrayindex[i]
return i, j
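# Sketch: rasterize a single neck-to-nose limb onto an empty canvas; draw_bodypose expects
# keypoint coordinates normalized to [0, 1].
if __name__ == "__main__":
    keypoints = [None] * 18
    keypoints[0] = Keypoint(x=0.50, y=0.30)  # nose
    keypoints[1] = Keypoint(x=0.50, y=0.45)  # neck
    canvas = draw_bodypose(np.zeros((256, 256, 3), dtype=np.uint8), keypoints)
    print(canvas.any())  # True: the neck-nose limb and both joint circles were drawn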

View File

@@ -0,0 +1,260 @@
import os
import random
import cv2
import numpy as np
import torch
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), "ckpts")
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def make_noise_disk(H, W, C, F):
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
noise = noise[F : F + H, F : F + W]
noise -= np.min(noise)
noise /= np.max(noise)
if C == 1:
noise = noise[:, :, None]
return noise
def nms(x, t, s):
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
z = np.zeros_like(y, dtype=np.uint8)
z[y > t] = 255
return z
def min_max_norm(x):
x -= np.min(x)
x /= np.maximum(np.max(x), 1e-5)
return x
def safe_step(x, step=2):
y = x.astype(np.float32) * float(step + 1)
y = y.astype(np.int32).astype(np.float32) / float(step)
return y
def img2mask(img, H, W, low=10, high=90):
assert img.ndim == 3 or img.ndim == 2
assert img.dtype == np.uint8
if img.ndim == 3:
y = img[:, :, random.randrange(0, img.shape[2])]
else:
y = img
y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
if random.uniform(0, 1) < 0.5:
y = 255 - y
return y < np.percentile(y, random.randrange(low, high))
def resize_image(input_image, resolution):
H, W, C = input_image.shape
H = float(H)
W = float(W)
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(np.round(H / 64.0)) * 64
W = int(np.round(W / 64.0)) * 64
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
return img
def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def ade_palette():
"""ADE20K palette that maps each class to RGB values."""
return [
[120, 120, 120],
[180, 120, 120],
[6, 230, 230],
[80, 50, 50],
[4, 200, 3],
[120, 120, 80],
[140, 140, 140],
[204, 5, 255],
[230, 230, 230],
[4, 250, 7],
[224, 5, 255],
[235, 255, 7],
[150, 5, 61],
[120, 120, 70],
[8, 255, 51],
[255, 6, 82],
[143, 255, 140],
[204, 255, 4],
[255, 51, 7],
[204, 70, 3],
[0, 102, 200],
[61, 230, 250],
[255, 6, 51],
[11, 102, 255],
[255, 7, 71],
[255, 9, 224],
[9, 7, 230],
[220, 220, 220],
[255, 9, 92],
[112, 9, 255],
[8, 255, 214],
[7, 255, 224],
[255, 184, 6],
[10, 255, 71],
[255, 41, 10],
[7, 255, 255],
[224, 255, 8],
[102, 8, 255],
[255, 61, 6],
[255, 194, 7],
[255, 122, 8],
[0, 255, 20],
[255, 8, 41],
[255, 5, 153],
[6, 51, 255],
[235, 12, 255],
[160, 150, 20],
[0, 163, 255],
[140, 140, 140],
[250, 10, 15],
[20, 255, 0],
[31, 255, 0],
[255, 31, 0],
[255, 224, 0],
[153, 255, 0],
[0, 0, 255],
[255, 71, 0],
[0, 235, 255],
[0, 173, 255],
[31, 0, 255],
[11, 200, 200],
[255, 82, 0],
[0, 255, 245],
[0, 61, 255],
[0, 255, 112],
[0, 255, 133],
[255, 0, 0],
[255, 163, 0],
[255, 102, 0],
[194, 255, 0],
[0, 143, 255],
[51, 255, 0],
[0, 82, 255],
[0, 255, 41],
[0, 255, 173],
[10, 0, 255],
[173, 255, 0],
[0, 255, 153],
[255, 92, 0],
[255, 0, 255],
[255, 0, 245],
[255, 0, 102],
[255, 173, 0],
[255, 0, 20],
[255, 184, 184],
[0, 31, 255],
[0, 255, 61],
[0, 71, 255],
[255, 0, 204],
[0, 255, 194],
[0, 255, 82],
[0, 10, 255],
[0, 112, 255],
[51, 0, 255],
[0, 194, 255],
[0, 122, 255],
[0, 255, 163],
[255, 153, 0],
[0, 255, 10],
[255, 112, 0],
[143, 255, 0],
[82, 0, 255],
[163, 255, 0],
[255, 235, 0],
[8, 184, 170],
[133, 0, 255],
[0, 255, 92],
[184, 0, 255],
[255, 0, 31],
[0, 184, 255],
[0, 214, 255],
[255, 0, 112],
[92, 255, 0],
[0, 224, 255],
[112, 224, 255],
[70, 184, 160],
[163, 0, 255],
[153, 0, 255],
[71, 255, 0],
[255, 0, 163],
[255, 204, 0],
[255, 0, 143],
[0, 255, 235],
[133, 255, 0],
[255, 0, 235],
[245, 0, 255],
[255, 0, 122],
[255, 245, 0],
[10, 190, 212],
[214, 255, 0],
[0, 204, 255],
[20, 0, 255],
[255, 255, 0],
[0, 153, 255],
[0, 41, 255],
[0, 255, 204],
[41, 0, 255],
[41, 255, 0],
[173, 0, 255],
[0, 245, 255],
[71, 0, 255],
[122, 0, 255],
[0, 255, 184],
[0, 92, 255],
[184, 255, 0],
[0, 133, 255],
[255, 214, 0],
[25, 194, 194],
[102, 255, 0],
[92, 0, 255],
]
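# Sketch: HWC3 promotes grayscale or RGBA uint8 arrays to 3-channel RGB, and resize_image
# scales the short side to `resolution` while snapping both sides to multiples of 64.
if __name__ == "__main__":
    gray = np.zeros((100, 160), dtype=np.uint8)
    rgb = HWC3(gray)
    print(rgb.shape)                     # (100, 160, 3)
    print(resize_image(rgb, 512).shape)  # (512, 832, 3): 100 -> 512, 160 * 5.12 = 819.2 snaps to 832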

View File

@@ -0,0 +1,559 @@
# type: ignore
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.controlnet import zero_module
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.outputs import BaseOutput
from invokeai.backend.bria.transformer_bria import (
EmbedND,
FluxSingleTransformerBlock,
FluxTransformerBlock,
TimestepProjEmbeddings,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
BRIA_CONTROL_MODES = Literal["depth", "canny", "colorgrid", "recolor", "tile", "pose"]
class BriaControlModes(Enum):
depth = 0
canny = 1
colorgrid = 2
recolor = 3
tile = 4
pose = 5
@dataclass
class BriaControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
controlnet_single_block_samples: Tuple[torch.Tensor]
class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Optional[List[int]] = None,
num_mode: Optional[int] = None,
rope_theta: int = 10000,
time_theta: int = 10000,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
# self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
# text_time_guidance_cls = (
# CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
# )
# self.time_text_embed = text_time_guidance_cls(
# embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
# )
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_single_layers)
]
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.union = num_mode is not None and num_mode > 0
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self):
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@classmethod
def from_transformer(
cls,
transformer,
num_layers: int = 4,
num_single_layers: int = 10,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
load_weights_from_transformer=True,
):
config = transformer.config
config["num_layers"] = num_layers
config["num_single_layers"] = num_single_layers
config["attention_head_dim"] = attention_head_dim
config["num_attention_heads"] = num_attention_heads
controlnet = cls(**config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
controlnet.single_transformer_blocks.load_state_dict(
transformer.single_transformer_blocks.state_dict(), strict=False
)
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
return controlnet
def forward(
self,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor = None,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[BriaControlNetOutput, Tuple]:
"""
The [`BriaControlNetModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
controlnet_mode (`torch.Tensor`):
The mode tensor of shape `(batch_size, 1)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`BriaControlNetOutput`] instead of a plain tuple.
Returns:
If `return_dict` is True, a [`BriaControlNetOutput`] is returned, otherwise a `tuple`
`(controlnet_block_samples, controlnet_single_block_samples)` is returned.
"""
if guidance is not None:
print("guidance is not supported in BriaControlNetModel")
if pooled_projections is not None:
print("pooled_projections is not supported in BriaControlNetModel")
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
# Convert controlnet_cond to the same dtype as the model weights
controlnet_cond = controlnet_cond.to(dtype=self.controlnet_x_embedder.weight.dtype)
# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
timestep = timestep.to(hidden_states.dtype) # Original code was * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) # Original code was * 1000
else:
guidance = None
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
if self.union:
# union mode
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# Validate controlnet_mode values are within the valid range
if torch.any(controlnet_mode < 0) or torch.any(controlnet_mode >= self.num_mode):
raise ValueError(
f"`controlnet_mode` values must be in range [0, {self.num_mode - 1}], but got values outside this range"
)
# union mode emb
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]: # duplicate mode emb for each batch
controlnet_mode_emb = controlnet_mode_emb.expand(
encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2]
)
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0)
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
block_samples = ()
for _, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for _, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
# controlnet block
controlnet_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks, strict=False):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
controlnet_single_block_samples = ()
for single_block_sample, controlnet_block in zip(
single_block_samples, self.controlnet_single_blocks, strict=False
):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
# scaling
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
controlnet_single_block_samples = (
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)
return BriaControlNetOutput(
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
class BriaMultiControlNetModel(ModelMixin):
r"""
`BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel
This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be
compatible with `BriaControlNetModel`.
Args:
controlnets (`List[BriaControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You must set multiple
`BriaControlNetModel` as a list.
"""
def __init__(self, controlnets):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
hidden_states: torch.FloatTensor,
controlnet_cond: List[torch.tensor],
controlnet_mode: List[torch.tensor],
conditioning_scale: List[float],
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[BriaControlNetOutput, Tuple]:
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
if len(self.nets) == 1 and self.nets[0].union:
controlnet = self.nets[0]
for i, (image, mode, scale) in enumerate(
zip(controlnet_cond, controlnet_mode, conditioning_scale, strict=False)
):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(
control_block_samples, block_samples, strict=False
)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples, strict=False
)
]
# Regular Multi-ControlNets
# load all ControlNets into memories
else:
for i, (image, mode, scale, controlnet) in enumerate(
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets, strict=False)
):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
if block_samples is not None and control_block_samples is not None:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(
control_block_samples, block_samples, strict=False
)
]
if single_block_samples is not None and control_single_block_samples is not None:
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples, strict=False
)
]
return control_block_samples, control_single_block_samples
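A minimal usage sketch for the ControlNet model above (editorial addition, not part of the committed diff). It assumes a loaded Bria transformer and already-packed latents, control latents, prompt embeddings and position ids; all tensor variables below are placeholders, and only calls defined in this file are used.

# Editorial sketch: derive a Bria ControlNet from an existing transformer and run one denoising step.
controlnet = BriaControlNetModel.from_transformer(transformer, num_layers=4, num_single_layers=10)
block_samples, single_block_samples = controlnet(
    hidden_states=latents,              # packed image latents, (batch, seq_len, channels)
    controlnet_cond=control_latents,    # packed control latents, same shape
    conditioning_scale=1.0,
    timestep=timestep,
    encoder_hidden_states=prompt_embeds,
    txt_ids=text_ids,
    img_ids=latent_image_ids,
    return_dict=False,
)
# The two tuples are then passed to the Bria transformer as controlnet residuals.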

View File

@@ -0,0 +1,68 @@
from typing import List, Tuple
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from PIL import Image
@torch.no_grad()
def prepare_control_images(
vae: AutoencoderKL,
control_images: list[Image.Image],
control_modes: list[int],
width: int,
height: int,
device: torch.device,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
tensored_control_images = []
tensored_control_modes = []
for idx, control_image_ in enumerate(control_images):
tensored_control_image = _prepare_image(
image=control_image_,
width=width,
height=height,
device=device,
dtype=vae.dtype,
)
height, width = tensored_control_image.shape[-2:]
# vae encode
tensored_control_image = vae.encode(tensored_control_image).latent_dist.sample()
tensored_control_image = (tensored_control_image) * vae.config.scaling_factor
# pack
height_control_image, width_control_image = tensored_control_image.shape[2:]
tensored_control_image = _pack_latents(
tensored_control_image,
height_control_image,
width_control_image,
)
tensored_control_images.append(tensored_control_image)
tensored_control_modes.append(
torch.tensor(control_modes[idx]).expand(tensored_control_image.shape[0]).to(device, dtype=torch.long)
)
return tensored_control_images, tensored_control_modes
def _prepare_image(
image: Image.Image,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
image = image.convert("RGB")
image = VaeImageProcessor(vae_scale_factor=16).preprocess(image, height=height, width=width)
image = image.repeat_interleave(1, dim=0)
image = image.to(device=device, dtype=dtype)
return image
def _pack_latents(latents, height, width):
latents = latents.view(1, 4, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(1, (height // 2) * (width // 2), 16)
return latents
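A quick shape check for the helpers above (editorial sketch; it assumes `_pack_latents` from this module is importable). The packing hardcodes a batch of one and four latent channels, so a 64x64 VAE latent becomes 32 * 32 = 1024 tokens of width 16:

import torch

latent = torch.randn(1, 4, 64, 64)       # one 4-channel VAE latent
packed = _pack_latents(latent, 64, 64)   # view/permute/reshape -> (1, 32 * 32, 16)
assert packed.shape == (1, 1024, 16)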

View File

@@ -0,0 +1,636 @@
from typing import Any, Callable, Dict, List, Optional, Union
import diffusers
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusion3Pipeline
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> image = pipe(prompt).images[0]
>>> image.save("sd3.png")
```
"""
T5_PRECISION = torch.float16
"""
Based on FluxPipeline with several changes:
- no pooled embeddings
- We use zero padding for prompts
- No guidance embedding since this is not a distilled version
"""
class BriaPipeline(FluxPipeline):
r"""
Args:
transformer ([`BriaTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Stable Diffusion 3 uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
"""
def __init__(
self,
transformer: BriaTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
):
self.register_modules(
vae=vae,
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
# TODO - why different from official flux (-1)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k
# T5 is sensitive to precision, so we use the precision used for precompute and cast as needed
if self.vae.config.shift_factor is None:
self.vae.config.shift_factor = 0
self.vae.to(dtype=torch.float32)
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = get_t5_prompt_embeds(
self.tokenizer,
self.text_encoder,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=self.transformer.dtype)
if do_classifier_free_guidance and negative_prompt_embeds is None:
if not is_ng_none(negative_prompt):
negative_prompt = (
batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = get_t5_prompt_embeds(
self.tokenizer,
self.text_encoder,
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=self.transformer.dtype)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, negative_prompt_embeds, text_ids
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
max_sequence_length: int = 128,
clip_value: Union[None, float] = None,
normalize: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 30):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain
tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
callback_on_step_end_tensor_inputs = (
["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
)
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
(prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4  # due to patch=2, we divide by 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
if (
isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
and self.scheduler.config["use_dynamic_shifting"]
):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1] # Shift by height - Why just height?
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
else:
# 4. Prepare timesteps
# Sample from training sigmas
if isinstance(self.scheduler, DDIMScheduler) or isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, None, None
)
else:
sigmas = get_original_sigmas(
num_train_timesteps=self.scheduler.config.num_train_timesteps,
num_inference_steps=num_inference_steps,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Support different diffusers versions
if diffusers.__version__ >= "0.32.0":
latent_image_ids = latent_image_ids[0]
text_ids = text_ids[0]
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# This predicts "v" from flow-matching or eps from diffusion
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
cfg_noise_pred_text = noise_pred_text.std()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if normalize:
noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred
if clip_value:
assert clip_value > 0
noise_pred = noise_pred.clip(-clip_value, clip_value)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
def to(self, *args, **kwargs):
DiffusionPipeline.to(self, *args, **kwargs)
# T5 is sensitive to precision, so we use the precision used for precompute and cast as needed
self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
for block in self.text_encoder.encoder.block:
block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
if self.vae.config.shift_factor == 0 and self.vae.dtype != torch.float32:
self.vae.to(dtype=torch.float32)
return self
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
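The two static helpers above are exact inverses for the shapes this pipeline produces. An editorial round-trip check (using the import path that appears in the next file): a 1024x1024 image with `vae_scale_factor` 16 corresponds to a 4-channel 128x128 latent, i.e. 4096 packed tokens of width 16.

import torch
from invokeai.backend.bria.pipeline_bria import BriaPipeline

latents = torch.randn(1, 4, 128, 128)
packed = BriaPipeline._pack_latents(latents, 1, 4, 128, 128)      # (1, 4096, 16)
restored = BriaPipeline._unpack_latents(packed, 1024, 1024, 16)   # (1, 4, 128, 128)
assert torch.equal(latents, restored)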

View File

@@ -0,0 +1,671 @@
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union
import diffusers
import numpy as np
import torch
from diffusers import AutoencoderKL  # Waiting for diffusers update
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND, logging
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
from invokeai.backend.bria.controlnet_bria import BriaControlNetModel
from invokeai.backend.bria.pipeline_bria import BriaPipeline
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class BriaControlNetPipeline(BriaPipeline):
r"""
Args:
transformer ([`BriaTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Stable Diffusion 3 uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__( # EYAL - removed clip text encoder + tokenizer
self,
transformer: BriaTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
controlnet: BriaControlNetModel,
):
super().__init__(
transformer=transformer, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
)
self.register_modules(controlnet=controlnet)
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
def prepare_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
num_channels_latents = self.transformer.config.in_channels // 4
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# For the single-ControlNet path, `control_mode` must be a single int (or None); broadcast it over the batch.
if control_mode is not None:
if not isinstance(control_mode, int):
raise ValueError(" For `BriaControlNet`, `control_mode` should be an `int` or `None`")
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
return control_image, control_mode
def prepare_multi_control(
self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode
):
num_channels_latents = self.transformer.config.in_channels // 4
control_images = []
for _, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image(
image=control_image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]
# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
control_images.append(control_image_)
control_image = control_images
# Here we ensure that `control_mode` has the same length as the control_image.
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
raise ValueError(
"For Multi-ControlNet, `control_mode` must be a list of the same "
+ " length as the number of controlnets (control images) specified"
)
if not isinstance(control_mode, list):
control_mode = [control_mode] * len(control_image)
# set control mode
control_modes = []
for cmode in control_mode:
if cmode is None:
cmode = -1
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
control_modes.append(control_mode)
control_mode = control_modes
return control_image, control_mode
def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end, strict=False)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps)
return controlnet_keep
def get_control_start_end(self, control_guidance_start, control_guidance_end):
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = 1 # TODO - why is this 1?
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
return control_guidance_start, control_guidance_end
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 3.5,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_image: Optional[PipelineImageInput] = None,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
latent_image_ids: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
text_ids: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
max_sequence_length: int = 128,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 30):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain
tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
control_guidance_start, control_guidance_end = self.get_control_start_end(
control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end
)
# 1. Check inputs. Raise error if not correct
callback_on_step_end_tensor_inputs = (
["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
)
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
device = self._execution_device
# 4. Prepare timesteps
if (
isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
and self.scheduler.config["use_dynamic_shifting"]
):
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
# Determine image sequence length
if control_image is not None:
if isinstance(control_image, list):
image_seq_len = control_image[0].shape[1]
else:
image_seq_len = control_image.shape[1]
else:
# Use latents sequence length when no control image is provided
image_seq_len = latents.shape[1]
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
else:
# 5. Prepare timesteps
sigmas = get_original_sigmas(
num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Create tensor stating which controlnets to keep
if control_image is not None:
controlnet_keep = self.get_controlnet_keep(
timesteps=timesteps,
control_guidance_start=control_guidance_start,
control_guidance_end=control_guidance_end,
)
if diffusers.__version__ >= "0.32.0":
latent_image_ids = latent_image_ids[0]
text_ids = text_ids[0]
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# EYAL - added the CFG loop
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# Handling ControlNet
if control_image is not None:
if isinstance(controlnet_keep[i], list):
if isinstance(controlnet_conditioning_scale, list):
cond_scale = controlnet_conditioning_scale
else:
cond_scale = [
c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False)
]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=cond_scale,
timestep=timestep,
# guidance=guidance,
# pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)
else:
controlnet_block_samples, controlnet_single_block_samples = None, None
# This predicts "v" from flow-matching
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
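# Minimal, self-contained sketch of the classifier-free guidance combine used in the
# denoising loop above. The shapes and the guidance scale are made up for illustration only.
if __name__ == "__main__":
    import torch

    guidance_scale = 5.0
    # When CFG is enabled the latent batch is duplicated, so the prediction holds the
    # unconditional and conditional halves stacked along the batch dimension.
    noise_pred = torch.randn(2, 4096, 64)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(tuple(guided.shape))  # (1, 4096, 64)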
def encode_prompt(
prompt: Union[str, List[str]],
tokenizer: T5TokenizerFast,
text_encoder: T5EncoderModel,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
max_sequence_length (`int`, defaults to 128):
Maximum sequence length to use when encoding the prompt with the T5 text encoder.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or torch.device("cuda")
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
# dynamically adjust the LoRA scale
if text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
dtype = text_encoder.dtype if text_encoder is not None else torch.float32
if prompt_embeds is None:
prompt_embeds = get_t5_prompt_embeds(
tokenizer,
text_encoder,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=dtype)
if negative_prompt_embeds is None:
if not is_ng_none(negative_prompt):
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = get_t5_prompt_embeds(
tokenizer,
text_encoder,
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=dtype)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
if text_encoder is not None:
if USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(text_encoder, lora_scale)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, negative_prompt_embeds, text_ids
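# Minimal sketch of the zero-padding behaviour above: with no negative prompt, the negative
# embeddings default to zeros with the prompt embeddings' shape, and `text_ids` is an all-zero
# (batch, seq_len, 3) tensor. All shapes here are made up for illustration.
if __name__ == "__main__":
    import torch

    batch_size, seq_len, embed_dim = 1, 128, 4096
    prompt_embeds = torch.randn(batch_size, seq_len, embed_dim)
    negative_prompt_embeds = torch.zeros_like(prompt_embeds)  # stands in for an empty negative prompt
    text_ids = torch.zeros(batch_size, seq_len, 3)
    print(negative_prompt_embeds.abs().sum().item(), tuple(text_ids.shape))  # 0.0 (1, 128, 3)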
def prepare_latents(
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: torch.Generator,
latents: Optional[torch.FloatTensor] = None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
vae_scale_factor = 16
height = 2 * (int(height) // vae_scale_factor)
width = 2 * (int(width) // vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
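# Minimal sketch of the positional ids built above, using a made-up 4x4 latent grid:
# channel 0 stays zero, channel 1 carries the row index and channel 2 the column index.
# Assumes `_prepare_latent_image_ids` (defined above) is in scope.
if __name__ == "__main__":
    import torch

    ids = _prepare_latent_image_ids(batch_size=1, height=4, width=4, device="cpu", dtype=torch.float32)
    print(tuple(ids.shape))  # (1, 16, 3): one id per latent position
    print(ids[0, :4])        # first grid row: (0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)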
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
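# Minimal sketch of the 2x2 packing above with made-up shapes: a (B, C, H, W) latent is
# rearranged so every 2x2 spatial patch moves into the channel axis, giving
# (B, (H/2)*(W/2), C*4). Assumes `_pack_latents` (defined above) is in scope.
if __name__ == "__main__":
    import torch

    b, c, h, w = 1, 16, 8, 8
    packed = _pack_latents(torch.randn(b, c, h, w), b, c, h, w)
    print(tuple(packed.shape))  # (1, 16, 64): 4*4 patches, 16*4 channels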


@@ -0,0 +1,322 @@
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from invokeai.backend.bria.bria_utils import FluxPosEmbed as EmbedND
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Timesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class TimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = Timesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
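# Minimal sketch of the timestep embedding path above: scalar timesteps are lifted to a
# 256-channel sinusoidal embedding and then projected to the transformer width. The value
# 3072 simply matches the default inner_dim (24 heads x 128); the timestep values are made up.
# Assumes `TimestepProjEmbeddings` (defined above) and diffusers are available.
if __name__ == "__main__":
    import torch

    embed = TimestepProjEmbeddings(embedding_dim=3072, time_theta=10000)
    t = torch.tensor([0.0, 500.0, 999.0])
    temb = embed(t, dtype=torch.float32)
    print(tuple(temb.shape))  # (3, 3072)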
"""
Based on FluxPipeline with several changes:
- no pooled embeddings
- We use zero padding for prompts
- No guidance embedding since this is not a distilled version
"""
class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
A transformer model based on the one introduced in Flux, adapted for Bria.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Parameters:
patch_size (`int`, defaults to 1): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`, *optional*, defaults to `None`): Unused in this implementation; kept for configuration compatibility.
guidance_embeds (`bool`, defaults to `False`): Whether to use guidance embeddings.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: Optional[int] = None,
guidance_embeds: bool = False,
axes_dims_rope: Optional[List[int]] = None,
rope_theta=10000,
time_theta=10000,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
# if pooled_projection_dim:
# self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")
if guidance_embeds:
self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)  # time_theta is required by TimestepProjEmbeddings
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`BriaTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
controlnet_block_samples / controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
Residuals from a ControlNet that, if provided, are added to the residuals of the corresponding transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
else:
guidance = None
# temb = (
# self.time_text_embed(timestep, pooled_projections)
# if guidance is None
# else self.time_text_embed(timestep, guidance, pooled_projections)
# )
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
# if pooled_projections:
# temb+=self.pooled_text_embed(pooled_projections)
if guidance is not None:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 2:
ids = torch.cat((txt_ids, img_ids), dim=0)
else:
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
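# Minimal sketch: instantiate a miniature, made-up configuration of the model above just to
# show how the config fields determine module widths (real Bria checkpoints use the defaults
# in `__init__`). Assumes `BriaTransformer2DModel` (defined above) and its imports are available.
if __name__ == "__main__":
    tiny = BriaTransformer2DModel(
        patch_size=1,
        in_channels=4,
        num_layers=1,
        num_single_layers=1,
        attention_head_dim=8,
        num_attention_heads=2,
        joint_attention_dim=32,
    )
    print(tiny.inner_dim)  # 16 = num_attention_heads * attention_head_dim
    print(sum(p.numel() for p in tiny.parameters()))  # total parameter count of the tiny config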

Some files were not shown because too many files have changed in this diff.