Compare commits


143 Commits

Author SHA1 Message Date
psychedelicious
c451f52ea3 chore(ui): lint 2024-08-22 21:00:09 +10:00
psychedelicious
8a2c78f2e1 fix(ui): dynamic prompts not recalculating when deleting or updating a style preset
The root cause was the active style preset not being reset when it was deleted, or no longer present in the list of style presets.

- Add extra reducer to `stylePresetSlice` to reset the active preset if it is deleted or otherwise unavailable
- Update the dynamic prompts listener to trigger on delete/update/list of style presets
2024-08-22 21:00:09 +10:00
psychedelicious
bcc78bde9b chore: bump version to v4.2.8 2024-08-22 21:00:09 +10:00
Васянатор
054bb6fe0a translationBot(ui): update translation (Russian)
Currently translated at 100.0% (1367 of 1367 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
Riccardo Giovanetti
4f4aa6d92e translationBot(ui): update translation (Italian)
Currently translated at 98.4% (1346 of 1367 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
Hosted Weblate
eac51ac6f5 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
psychedelicious
9f349a7c0a fix(ui): do not constrain width of hide/show boards button
lets translations display fully
2024-08-22 11:36:07 +10:00
psychedelicious
918afa5b15 fix(ui): show more of current board name 2024-08-22 11:36:07 +10:00
psychedelicious
eb1113f95c feat(ui): add translation string for "Upscale" 2024-08-22 11:36:07 +10:00
psychedelicious
4f4ba7b462 tidy(ui): clean up ActiveStylePreset markup 2024-08-21 09:06:41 +10:00
Mary Hipp
2298be0e6b fix(ui): error handling if unable to convert image URL to blob 2024-08-21 09:06:41 +10:00
Mary Hipp
63494dfca7 remove extra slash in exports path 2024-08-21 09:06:41 +10:00
Mary Hipp
36a1d39454 fix(ui): handle badge styling when template name is long 2024-08-21 09:06:41 +10:00
Mary Hipp
a6f6d5c400 fix(ui): add loading state to button when creating or updating a style preset 2024-08-21 09:06:41 +10:00
Mary Hipp
e85f221aca fix(ui): clear prompt template when prompts are recalled 2024-08-21 09:04:35 +10:00
Mary Hipp
d4797e37dc fix(ui): properly unwrap delete style preset API request so that error is caught 2024-08-19 16:12:39 -04:00
Mary Hipp
3e7923d072 fix(api): allow updating of type for style preset 2024-08-19 16:12:39 -04:00
psychedelicious
a85d69ce3d tidy(ui): getViewModeChunks.tsx -> .ts 2024-08-19 08:25:39 +10:00
psychedelicious
96db006c99 fix(ui): edge case with getViewModeChunks 2024-08-19 08:25:39 +10:00
psychedelicious
8ca57d03d8 tests(ui): add tests for getViewModeChunks 2024-08-19 08:25:39 +10:00
psychedelicious
6c404ce5f8 fix(ui): prompt template preset preview out of order 2024-08-19 08:25:39 +10:00
psychedelicious
584e07182b fix(ui): use translations for style preset strings 2024-08-17 21:27:53 +10:00
psychedelicious
f787e9acf6 chore: bump version v4.2.8rc2 2024-08-16 21:47:06 +10:00
psychedelicious
5a24b89e54 fix(app): include style preset defaults in build 2024-08-16 21:47:06 +10:00
psychedelicious
9b482e2a4f chore: bump version to v4.2.8rc1 2024-08-16 10:53:19 +10:00
Max
df4dbe2d57 Fix invoke.sh not detecting symlinks
When invoke.sh is executed using a symlink with a working directory outside of InvokeAI's root directory, it will fail.

invoke.sh attempts to cd into the correct directory at the start of the script, but will cd into the directory of the symlink instead. This commit fixes that.
2024-08-16 10:40:59 +10:00
psychedelicious
713bd11177 feat(ui, api): prompt template export (#6745)
## Summary

Adds option to download all prompt templates to a CSV

## Related Issues / Discussions

## QA Instructions

## Merge Plan

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-16 10:38:50 +10:00
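A minimal sketch of the CSV export described in this PR, using only the standard library. The column names mirror the `export_style_presets` handler shown in the diffs further down; the function itself is illustrative rather than the actual endpoint code.

```python
import csv
import io


def presets_to_csv(presets: list[dict[str, str]]) -> str:
    """Serialize style presets into an in-memory CSV string."""
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow(["name", "prompt", "negative_prompt"])
    for preset in presets:
        writer.writerow([preset["name"], preset["positive_prompt"], preset["negative_prompt"]])
    return output.getvalue()


# Example:
# presets_to_csv([{"name": "Cinematic", "positive_prompt": "{prompt}, cinematic", "negative_prompt": "blurry"}])
```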
psychedelicious
182571df4b Merge branch 'main' into maryhipp/export-presets 2024-08-16 10:17:07 +10:00
psychedelicious
29bfe492b6 ui: translations update from weblate (#6746)
Translations update from [Hosted Weblate](https://hosted.weblate.org)
for [InvokeAI/Web
UI](https://hosted.weblate.org/projects/invokeai/web-ui/).



Current translation status:

![Weblate translation
status](https://hosted.weblate.org/widget/invokeai/web-ui/horizontal-auto.svg)
2024-08-16 10:16:51 +10:00
psychedelicious
3fb4e3050c feat(ui): focus in textarea after inserting placeholder 2024-08-16 10:14:25 +10:00
psychedelicious
39c7ec3cd9 feat(ui): per type fallbacks for templates 2024-08-16 10:11:43 +10:00
psychedelicious
26bfbdec7f feat(ui): use buttons instead of menu for preset import/export 2024-08-16 09:58:19 +10:00
psychedelicious
7a3eaa8da9 feat(api): save file as prompt_templates.csv 2024-08-16 09:51:46 +10:00
Mary Hipp
599db7296f export only user style presets 2024-08-15 16:07:32 -04:00
Riccardo Giovanetti
042aab4295 translationBot(ui): update translation (Italian)
Currently translated at 98.6% (1340 of 1359 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-15 20:44:02 +02:00
Mary Hipp
24f298283f clean up, add context menu to import/download templates 2024-08-15 12:39:55 -04:00
Mary Hipp
68dac6349d Merge remote-tracking branch 'origin/main' into maryhipp/export-presets 2024-08-15 11:21:56 -04:00
chainchompa
b675fc19e8 feat: add base prop for selectedWorkflow to allow loading a workflow on launch (#6742)
## Summary
added a base prop for selectedWorkflow to allow loading a workflow on
launch

## Related Issues / Discussions

## QA Instructions
can test by loading InvokeAIUI with a selectedWorkflow prop of the
workflow ID

## Merge Plan

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-15 10:52:23 -04:00
chainchompa
659019cfd6 Merge branch 'main' into chainchompa/preselect-workflows 2024-08-15 10:40:44 -04:00
Mary Hipp
dcd61e1f82 pin ruff version in python check gha 2024-08-15 09:47:49 -04:00
Mary Hipp
f5c99b1488 exclude jupyter notebooks from ruff 2024-08-15 09:47:49 -04:00
Mary Hipp
810be3e1d4 update import directions to include JSON 2024-08-15 09:47:49 -04:00
psychedelicious
60d754d1df feat(api): tidy style presets import logic
- Extract parsing into utility function
- Log import errors
- Forbid extra properties on the imported data
2024-08-15 09:47:49 -04:00
psychedelicious
bd07c86db9 feat(ui): make style preset menu trigger look like button 2024-08-15 09:47:49 -04:00
psychedelicious
bcbf8b6bd8 feat(ui): revert to using {prompt} for prompt template placeholder 2024-08-15 09:47:49 -04:00
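For context, a hypothetical sketch of how a `{prompt}` placeholder can be substituted into a template. The real logic lives in the UI code; the fallback behaviour when no placeholder is present is an assumption here, not the repository's implementation.

```python
PLACEHOLDER = "{prompt}"


def apply_template(template: str, user_prompt: str) -> str:
    """Substitute the user's prompt into a style preset template."""
    if PLACEHOLDER in template:
        return template.replace(PLACEHOLDER, user_prompt)
    # Assumed fallback: append the template when no placeholder is present.
    return f"{user_prompt} {template}".strip()


# apply_template("{prompt}, oil painting", "a red fox") -> "a red fox, oil painting"
```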
psychedelicious
356661459b feat(api): support JSON for preset imports
This allows us to support Fooocus format presets.
2024-08-15 09:47:49 -04:00
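A rough sketch of what accepting both CSV and JSON imports can look like. The real parsing lives in `parse_presets_from_file`; the helper below is illustrative only.

```python
import csv
import io
import json


class UnsupportedFileTypeError(ValueError):
    """Raised when an uploaded preset file is neither CSV nor JSON."""


def parse_presets(filename: str, raw: bytes) -> list[dict]:
    """Dispatch on file extension and return a list of raw preset rows."""
    text = raw.decode("utf-8")
    if filename.endswith(".csv"):
        return list(csv.DictReader(io.StringIO(text)))
    if filename.endswith(".json"):
        return json.loads(text)  # e.g. a Fooocus-format list of preset objects
    raise UnsupportedFileTypeError(filename)
```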
psychedelicious
deb917825e feat(api): use pydantic validation during style preset import
- Enforce name is present and not an empty string
- Provide empty string as default for positive and negative prompt
- Add `positive_prompt` as validation alias for `prompt` field
- Strip whitespace automatically
- Create `TypeAdapter` to validate the whole list in one go
2024-08-15 09:47:49 -04:00
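A minimal pydantic v2 sketch of the validation described above; field names and constraints approximate the commit description rather than the repository's actual model.

```python
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter


class PresetImportRow(BaseModel):
    # Strip surrounding whitespace on all string fields and reject unknown keys.
    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")

    name: str = Field(min_length=1, description="Required, non-empty preset name")
    positive_prompt: str = Field(
        default="",
        validation_alias=AliasChoices("positive_prompt", "prompt"),
    )
    negative_prompt: str = Field(default="")


# Validate the whole import in one go.
PresetImportList = TypeAdapter(list[PresetImportRow])
rows = PresetImportList.validate_python([{"name": "Cinematic", "prompt": "  {prompt}, cinematic  "}])
```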
psychedelicious
15415c6d85 feat(ui): use dropzone for style preset upload
Easier to accept multiple file types and support drag and drop in the future.
2024-08-15 09:47:49 -04:00
Mary Hipp
76b0380b5f feat(ui): create component to upload CSV of style presets to import 2024-08-15 09:47:49 -04:00
Mary Hipp
2d58754789 feat(api): add endpoint to take a CSV, parse it, validate it, and create many style preset entries 2024-08-15 09:47:49 -04:00
chainchompa
9cdf1f599c Merge branch 'main' into chainchompa/preselect-workflows 2024-08-15 09:25:19 -04:00
chainchompa
268be97ba0 remove ref, make options optional for useGetLoadWorkflow 2024-08-15 09:18:41 -04:00
Mary Hipp
a9014673a0 wip export 2024-08-15 09:00:11 -04:00
psychedelicious
d36c43a10f ui: translations update from weblate (#6727)
Translations update from [Hosted Weblate](https://hosted.weblate.org)
for [InvokeAI/Web
UI](https://hosted.weblate.org/projects/invokeai/web-ui/).



Current translation status:

![Weblate translation
status](https://hosted.weblate.org/widget/invokeai/web-ui/horizontal-auto.svg)
2024-08-15 08:48:03 +10:00
Phrixus2023
54a5c4e482 translationBot(ui): update translation (Chinese (Simplified))
Currently translated at 98.1% (1296 of 1320 strings)

Co-authored-by: Phrixus2023 <920414016@qq.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2024-08-15 00:46:01 +02:00
Riccardo Giovanetti
5e09a244e3 translationBot(ui): update translation (Italian)
Currently translated at 98.5% (1336 of 1355 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1302 of 1321 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1302 of 1320 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-15 00:46:01 +02:00
chainchompa
88648dca1a change selectedWorkflow to selectedWorkflowId 2024-08-14 11:22:37 -04:00
chainchompa
8840df2b00 Merge branch 'main' into chainchompa/preselect-workflows 2024-08-14 09:02:12 -04:00
chainchompa
af159acbdf cleanup 2024-08-14 08:58:38 -04:00
chainchompa
471719bbbe add base prop for selectedWorkflow to allow loading a workflow on launch 2024-08-14 08:47:02 -04:00
psychedelicious
b126f2ffd5 feat(ui, api): prompt templates (#6729)
## Summary

Adds prompt templates to the UI. Demo video is attached.
* added default prompt templates to seed database on startup (these
cannot be edited or deleted by users via the UI)
* can create fresh prompt template, create from an image in gallery that
has prompt metadata, or copy an existing prompt template and modify
* if a template is active, can view what your prompt will be invoked as
by switching to "view mode"



https://github.com/user-attachments/assets/32d84e0c-b04c-48da-bae5-aa6eb685d209



## Related Issues / Discussions

## QA Instructions

## Merge Plan

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-14 12:49:31 +10:00
psychedelicious
9938f12ef0 Merge branch 'main' into maryhipp/style-presets 2024-08-14 12:33:30 +10:00
psychedelicious
982c266073 tidy: remove extra characters in prompt templates 2024-08-14 12:31:57 +10:00
psychedelicious
5c37391883 fix(ui): do not show [prompt] in preset preview 2024-08-14 12:29:05 +10:00
psychedelicious
ddeafc6833 fix(ui): minimize layout shift when overlaying preset prompt preview 2024-08-14 12:24:57 +10:00
psychedelicious
41b2d5d013 fix(ui): prompt preview not working when preset starts with [prompt] 2024-08-14 12:21:38 +10:00
psychedelicious
29d6f48901 fix(ui): prompt shows thru prompt label text 2024-08-14 12:01:49 +10:00
psychedelicious
d5c9f4e47f chore(ui): revert framer-motion upgrade
`framer-motion` 11 breaks a lot of stuff in profoundly unintuitive ways. The UI library rolled back its dependency, so this pulls in the latest version of that.
2024-08-14 06:12:00 +10:00
psychedelicious
24d73387d8 build(ui): fix chakra deps
We had multiple versions of @emotion/react, stemming from an extraneous dependency on @chakra-ui/react. Removed the extraneous dep.
2024-08-14 06:12:00 +10:00
Mary Hipp
e0d3927265 feat: add flag for allowPrivateStylePresets that shows a type field when creating a style preset 2024-08-13 14:08:54 -04:00
Mary Hipp
e5f7c2a9b7 add type safety / validation to form data payloads and allow type to be passed through api 2024-08-13 13:00:31 -04:00
Mary Hipp
b0760710d5 add the rest of default style presets, update image service to return default images correctly by name, add tooltip popover to images in UI 2024-08-13 11:33:15 -04:00
Mary Hipp
764accc921 update config docstring 2024-08-12 15:17:40 -04:00
Mary Hipp
6a01fce9c1 fix payloads for stringified data 2024-08-12 15:16:22 -04:00
Mary Hipp
9c732ac3b1 Merge remote-tracking branch 'origin/main' into maryhipp/style-presets 2024-08-12 14:53:45 -04:00
Mary Hipp
b70891c661 update description of placeholder in modal 2024-08-12 13:37:04 -04:00
Mary Hipp
4dbf851741 ui: add labels to prompt boxes 2024-08-12 13:33:39 -04:00
Mary Hipp
6c927a9fd4 move modal state into nanostore 2024-08-12 12:46:02 -04:00
Mary Hipp
096f001634 ui: add ability to copy template 2024-08-12 12:32:31 -04:00
Mary Hipp
4837e578b2 api: update dir path for style preset images, update payload for create/update formdata 2024-08-12 12:00:14 -04:00
Mary Hipp
1e547ef912 UI more pr feedback 2024-08-12 11:59:25 -04:00
psychedelicious
f6b8970bd1 fix(app): create reference to events task to prevent accidental GC
This wasn't a problem, but it's advised in the official docs so I've done it.
2024-08-12 07:49:58 +10:00
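The pattern referenced here, straight from the asyncio docs, keeps a strong reference to a fire-and-forget task so it cannot be garbage collected before it finishes. A minimal sketch:

```python
import asyncio

background_tasks: set[asyncio.Task] = set()


def start_background(coro) -> asyncio.Task:
    """Schedule a coroutine and hold a reference until it completes."""
    task = asyncio.create_task(coro)
    background_tasks.add(task)  # the event loop itself only keeps a weak reference
    task.add_done_callback(background_tasks.discard)
    return task
```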
psychedelicious
29325a7214 fix(app): use asyncio queue and existing event loop for events
Around the time we (I) implemented pydantic events, I noticed a short pause between progress images every 4 or 5 steps when generating with SDXL. It didn't happen with SD1.5, but I did notice that with SD1.5, we'd get 4 or 5 progress events simultaneously. I'd expect one event every ~25ms, matching my it/s with SD1.5. Mysterious!

Digging in, I found an issue related to our use of a synchronous queue for events. When the event queue is empty, we must call `asyncio.sleep` before checking again. We were sleeping for 100ms.

Said another way, every time we clear the event queue, we have to wait 100ms before another event can be dispatched, even if it is put on the queue immediately after we start waiting. In practice, this means our events get buffered into batches, dispatched once every 100ms.

This explains why I was getting batches of 4 or 5 SD1.5 progress events at once, but not the intermittent SDXL delay.

But this 100ms wait has another effect when events are put on the queue at intervals that don't perfectly line up with the 100ms wait. This is most noticeable when the time between events is >100ms, and can add up to 100ms of delay before the event is dispatched.

For example, say the queue is empty and we start a 100ms wait. Then, immediately after - like 0.01ms later - we push an event on to the queue. We still need to wait another 99.9ms before that event will be dispatched. That's the SDXL delay.

The easy fix is to reduce the sleep to something like 0.01 seconds, but this feels kinda dirty. Can't we just wait on the queue and dispatch every event immediately? Not with the normal synchronous queue - but we can with `asyncio.Queue`.

I switched the events queue to use `asyncio.Queue` (as seen in this commit), which lets us asynchronously wait on the queue in a loop.

Unfortunately, I ran into another issue - events now felt like their timing was inconsistent, but in a different way than with the 100ms sleep. The time between pushing events on the queue and dispatching them was not consistently ~0ms as I'd expect - it was highly variable from ~0ms up to ~100ms.

This is resolved by passing the asyncio loop directly into the events service and using its methods to create the task and interact with the queue. I don't fully understand why this resolved the issue, because either way we are interacting with the same event loop (as shown by `asyncio.get_running_loop()`). I suppose there's some scheduling magic happening.
2024-08-12 07:49:58 +10:00
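A stripped-down sketch of the resulting pattern: the consumer awaits an `asyncio.Queue` directly, and producers on other threads feed it via `call_soon_threadsafe`, so each event is dispatched as soon as it arrives instead of on a 100ms polling interval. This simplifies the `FastAPIEventService` change shown in the diffs below, with a print standing in for the real dispatch call.

```python
import asyncio
import threading


class EventDispatcher:
    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._stop = threading.Event()
        self._task = loop.create_task(self._run())  # keep a reference to the task

    def dispatch(self, event: str) -> None:
        # Safe to call from any thread; wakes the consumer immediately.
        self._loop.call_soon_threadsafe(self._queue.put_nowait, event)

    def stop(self) -> None:
        self._stop.set()
        self._loop.call_soon_threadsafe(self._queue.put_nowait, None)

    async def _run(self) -> None:
        while not self._stop.is_set():
            event = await self._queue.get()  # no polling, no sleep
            if event is None:  # sentinel used only to unblock on shutdown
                continue
            print("dispatching", event)  # stand-in for fastapi_events dispatch()
```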
psychedelicious
8ecf72838d fix(api): image downloads with correct filename
Closes #6730
2024-08-10 09:53:56 -04:00
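The fix, visible in the images router diff further down, sets a Content-Disposition header so the browser uses the image name when saving. A minimal FastAPI-style sketch:

```python
from fastapi import Response


def image_response(image_name: str, png_bytes: bytes) -> Response:
    response = Response(png_bytes, media_type="image/png")
    # "inline" keeps the image rendering in the browser; the filename is used for downloads.
    response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
    return response
```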
psychedelicious
c3ab8a6aa8 chore(ui): bump rest of deps 2024-08-10 07:45:23 -04:00
psychedelicious
1931aa3e70 chore(ui): typegen 2024-08-10 07:45:23 -04:00
psychedelicious
d3d8055055 feat(ui): update typegen script 2024-08-10 07:45:23 -04:00
psychedelicious
476b0a0403 chore(ui): bump openapi-typescript 2024-08-10 07:45:23 -04:00
psychedelicious
f66584713c fix(api): sort OpenAPI schema properties for InvocationOutputMap
This makes the schema output deterministic!
2024-08-10 07:45:23 -04:00
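The general idea, sketched: once a schema object's properties are inserted in sorted order, the serialized OpenAPI document becomes byte-for-byte reproducible. The helper below is illustrative only; the actual change lives in the app's OpenAPI generation code.

```python
def sorted_properties(schema: dict) -> dict:
    """Return a copy of a JSON-schema object with its properties in sorted order."""
    props = schema.get("properties", {})
    # Python dicts preserve insertion order, so inserting keys sorted is enough
    # to make the dumped schema deterministic across runs.
    return {**schema, "properties": {key: props[key] for key in sorted(props)}}
```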
psychedelicious
33624fc2fa fix(api): duplicate operation id for get_image_full
There's a FastAPI bug that results in the OpenAPI spec outputting the same operation id for each operation when specifying multiple HTTP methods.

- Discussion: https://github.com/tiangolo/fastapi/discussions/8449
- Pending PR to fix: https://github.com/tiangolo/fastapi/pull/10694

In our case, we have a `get_image_full` endpoint that handles GET and HEAD.

This results in an invalid OpenAPI schema. A workaround is to use two route decorators for the operation handler. This works as expected - HEAD requests get the header, and GET requests get the resource. And the OpenAPI schema is valid.
2024-08-10 07:45:23 -04:00
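A condensed sketch of the workaround: stack one route decorator per HTTP method on the same handler so each method gets its own operation id. The full version appears in the images router diff further down.

```python
from fastapi import APIRouter, Response

router = APIRouter()


# Each decorator registers its own operation, so the generated OpenAPI schema stays valid.
@router.get("/i/{image_name}/full", operation_id="get_image_full")
@router.head("/i/{image_name}/full", operation_id="get_image_full_head")
async def get_image_full(image_name: str) -> Response:
    return Response(b"...", media_type="image/png")  # placeholder body
```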
Mary Hipp
41c3e73a3c fix tests 2024-08-09 16:31:42 -04:00
Mary Hipp
97553a7de2 API/DB updates per PR feedback 2024-08-09 16:27:37 -04:00
Mary Hipp
12ba15bfa9 UI updates per PR feedback 2024-08-09 16:00:13 -04:00
Mary Hipp
09d1e190e7 show warning for maxUpscaleDimension if model tab is disabled 2024-08-09 14:07:55 -04:00
Mary Hipp
8eb5d08499 missed translation 2024-08-08 16:01:16 -04:00
Mary Hipp
9be6acde7d require name to submit style preset 2024-08-08 15:53:21 -04:00
Mary Hipp
5f83bb0069 update config docstring 2024-08-08 15:20:43 -04:00
Mary Hipp
b138882abc fix tests? 2024-08-08 15:18:32 -04:00
Mary Hipp
0cd7cdb52e remove send2trash 2024-08-08 15:13:36 -04:00
Mary Hipp
1d8b7e2bcf ruff 2024-08-08 15:08:45 -04:00
Mary Hipp
6461f4758d lint fix 2024-08-08 15:07:58 -04:00
Mary Hipp
3189ab6863 get dynamic prompts working 2024-08-08 15:07:23 -04:00
Mary Hipp
3f9a674d4b seed default presets and handle them in UI 2024-08-08 15:02:41 -04:00
Mary Hipp
587f59b25b focus on prompt textarea when exiting view mode by clicking 2024-08-08 14:38:50 -04:00
Mary Hipp
4952eada87 ruff format 2024-08-08 14:22:40 -04:00
Mary Hipp
581029ebaa ruff 2024-08-08 14:21:37 -04:00
Mary Hipp
42d68780de lint 2024-08-08 14:19:33 -04:00
Mary Hipp
28032a2f80 more cleanup 2024-08-08 14:18:05 -04:00
Mary Hipp
e381e021e9 knip lint 2024-08-08 14:00:17 -04:00
Mary Hipp
641af64f93 regenerate schema 2024-08-08 13:58:25 -04:00
Mary Hipp
a7b83c8b5b Merge remote-tracking branch 'origin/main' into maryhipp/style-presets 2024-08-08 13:56:59 -04:00
Mary Hipp
4cc41e0188 translations and lint fix 2024-08-08 13:56:37 -04:00
Mary Hipp
442fc02429 resize images to 100x100 for style preset images 2024-08-08 12:56:55 -04:00
Mary Hipp
9a4d075074 fix path for style_preset_images, fix png type when converting blobs to files, built view mode components 2024-08-08 12:31:20 -04:00
Sergey Borisov
17ff8196cb Remove tmp code 2024-08-07 22:06:05 -04:00
Sergey Borisov
68f993998a Add support for norm layer 2024-08-07 22:06:05 -04:00
Sergey Borisov
7da6120b39 Fix LoKR refactor bug 2024-08-07 22:06:05 -04:00
blessedcoolant
6cd40965c4 Depth Anything V2 (#6674)
- Updated the previous DepthAnything manual implementation to use the
`transformers` implementation instead, so we can get upstream features.
- Plugged in the DepthAnything models to be handled by Invoke's Model
Manager.
- `small_v2` model will use DepthAnythingV2. This has been added as a
new model option and is now also the default in the Linear UI.


![opera_TxRhmbFole](https://github.com/user-attachments/assets/2a25abe3-ba0b-4f97-b75a-2ce5fd6246e6)


# Merge

Review and merge.
2024-08-07 20:26:58 +05:30
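For orientation, running the small V2 model through the stock `transformers` pipeline looks roughly like this. The model id matches the `small_v2` entry in the controlnet processor diff below; resizing to the requested resolution is handled separately by the node.

```python
from PIL import Image
from transformers import pipeline

# Depth Anything V2 Small via the upstream transformers implementation.
depth_estimator = pipeline(
    task="depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
)

image = Image.open("input.png")
result = depth_estimator(image)
depth_map: Image.Image = result["depth"]  # PIL image of the predicted depth
depth_map.save("depth.png")
```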
Kent Keirsey
408a1d6dbb Merge branch 'main' into depth_anything_v2 2024-08-07 10:45:56 -04:00
Mary Hipp
0b0abfbe8f clean up image implementation 2024-08-07 10:36:38 -04:00
Mary Hipp
cc96dcf0ed style preset images 2024-08-07 09:58:27 -04:00
Mary Hipp
2604fd9fde a whole bunch of stuff 2024-08-06 15:31:13 -04:00
Mary Hipp
857d74bbfe wip apply and calculate prompt with interpolation 2024-08-05 19:11:48 -04:00
Mary Hipp
fd7a635777 (ui) the most basic crud ui: view list of presets, create a new preset, edit/delete existing presets 2024-08-05 15:48:23 -04:00
Mary Hipp
af9110e964 fix prompt concat logic 2024-08-05 13:42:28 -04:00
Mary Hipp
a61209206b remove custom SDXL prompts component 2024-08-05 13:40:46 -04:00
Mary Hipp
e05cc62e5f add style presets API layer to UI 2024-08-05 13:37:07 -04:00
blessedcoolant
4f8a4b0f22 Merge branch 'main' into depth_anything_v2 2024-08-03 00:38:57 +05:30
blessedcoolant
a743f3c9b5 fix: implement model to func for depth anything 2024-08-03 00:37:17 +05:30
Mary Hipp
217fe40d99 feat(api): add style_presets router, make sure all CRUD is working, add is_default 2024-08-02 12:29:54 -04:00
Mary Hipp
b76bf50b93 feat(db,api): create new table for style presets, build out record storage service for style presets 2024-08-01 22:20:11 -04:00
blessedcoolant
332bc9da5b fix: Update depth anything node default to v2 2024-07-31 23:52:29 +05:30
blessedcoolant
08def3da95 fix: Update canvas depth anything processor default to v2 2024-07-31 23:50:13 +05:30
blessedcoolant
daf899f9c4 fix: Move the manual image resizing out of the depth anything pipeline 2024-07-31 23:38:12 +05:30
blessedcoolant
13fb2d1f49 fix: Add Depth Anything V2 as a new option
It is also now the default in the UI, replacing Depth Anything V1 small
2024-07-31 23:29:43 +05:30
blessedcoolant
95dde802ea fix: assert the return depth map to be a PIL image 2024-07-31 23:22:01 +05:30
blessedcoolant
b4cf78a95d fix: make DA Pipeline a subclass of RawModel 2024-07-31 21:14:49 +05:30
blessedcoolant
18f89ed5ed fix: Make DepthAnything work with Invoke's Model Management 2024-07-31 03:57:54 +05:30
blessedcoolant
f170697ebe Merge branch 'main' into depth_anything_v2 2024-07-31 00:53:32 +05:30
blessedcoolant
556c6a1d84 fix: Update DepthAnything to use the transformers implementation 2024-07-31 00:51:55 +05:30
blessedcoolant
e5d9ca013e fix: use v1 models for large and base versions 2024-07-25 17:24:12 +05:30
blessedcoolant
4166c756ce wip: depth_anything_v2 init lint fixes 2024-07-25 14:41:22 +05:30
blessedcoolant
4f0dfbd34d wip: depth_anything_v2 initial implementation 2024-07-25 13:53:06 +05:30
140 changed files with 23884 additions and 21390 deletions

View File

@@ -62,7 +62,7 @@ jobs:
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff
run: pip install ruff==0.6.0
shell: bash
- name: ruff check

View File

@@ -17,7 +17,7 @@
set -eu
# Ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname "$0")
scriptdir=$(dirname $(readlink -f "$0"))
cd "$scriptdir"
. .venv/bin/activate

View File

@@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from logging import Logger
import torch
@@ -31,6 +32,8 @@ from invokeai.app.services.session_processor.session_processor_default import (
)
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -63,7 +66,12 @@ class ApiDependencies:
invoker: Invoker
@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
def initialize(
config: InvokeAIAppConfig,
event_handler_id: int,
loop: asyncio.AbstractEventLoop,
logger: Logger = logger,
) -> None:
logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}")
@@ -74,6 +82,7 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
style_presets_folder = config.style_presets_path
db = init_db(config=config, logger=logger, image_files=image_files)
@@ -84,7 +93,7 @@ class ApiDependencies:
board_images = BoardImagesService()
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
events = FastAPIEventService(event_handler_id, loop=loop)
bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
@@ -109,6 +118,8 @@ class ApiDependencies:
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
services = InvocationServices(
board_image_records=board_image_records,
@@ -134,6 +145,8 @@ class ApiDependencies:
workflow_records=workflow_records,
tensors=tensors,
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
)
ApiDependencies.invoker = Invoker(services)

View File

@@ -218,9 +218,8 @@ async def get_image_workflow(
raise HTTPException(status_code=404)
@images_router.api_route(
@images_router.get(
"/i/{image_name}/full",
methods=["GET", "HEAD"],
operation_id="get_image_full",
response_class=Response,
responses={
@@ -231,6 +230,18 @@ async def get_image_workflow(
404: {"description": "Image not found"},
},
)
@images_router.head(
"/i/{image_name}/full",
operation_id="get_image_full_head",
response_class=Response,
responses={
200: {
"description": "Return the full-resolution image",
"content": {"image/png": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> Response:
@@ -242,6 +253,7 @@ async def get_image_full(
content = f.read()
response = Response(content, media_type="image/png")
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
return response
except Exception:
raise HTTPException(status_code=404)

View File

@@ -0,0 +1,274 @@
import csv
import io
import json
import traceback
from typing import Optional
import pydantic
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
from invokeai.app.services.style_preset_records.style_preset_records_common import (
InvalidPresetImportDataError,
PresetData,
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordWithImage,
StylePresetWithoutId,
UnsupportedFileTypeError,
parse_presets_from_file,
)
class StylePresetFormData(BaseModel):
name: str = Field(description="Preset name")
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
type: PresetType = Field(description="Preset type")
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
@style_presets_router.get(
"/i/{style_preset_id}",
operation_id="get_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def get_style_preset(
style_preset_id: str = Path(description="The style preset to get"),
) -> StylePresetRecordWithImage:
"""Gets a style preset"""
try:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
except StylePresetNotFoundError:
raise HTTPException(status_code=404, detail="Style preset not found")
@style_presets_router.patch(
"/i/{style_preset_id}",
operation_id="update_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def update_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
style_preset_id: str = Path(description="The id of the style preset to update"),
data: str = Form(description="The data of the style preset to update"),
) -> StylePresetRecordWithImage:
"""Updates a style preset"""
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
else:
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
style_preset_id=style_preset_id, changes=changes
)
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
@style_presets_router.delete(
"/i/{style_preset_id}",
operation_id="delete_style_preset",
)
async def delete_style_preset(
style_preset_id: str = Path(description="The style preset to delete"),
) -> None:
"""Deletes a style preset"""
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
@style_presets_router.post(
"/",
operation_id="create_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def create_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
data: str = Form(description="The data of the style preset to create"),
) -> StylePresetRecordWithImage:
"""Creates a style preset"""
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
@style_presets_router.get(
"/",
operation_id="list_style_presets",
responses={
200: {"model": list[StylePresetRecordWithImage]},
},
)
async def list_style_presets() -> list[StylePresetRecordWithImage]:
"""Gets a page of style presets"""
style_presets_with_image: list[StylePresetRecordWithImage] = []
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
for preset in style_presets:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
style_presets_with_image.append(style_preset_with_image)
return style_presets_with_image
@style_presets_router.get(
"/i/{style_preset_id}/image",
operation_id="get_style_preset_image",
responses={
200: {
"description": "The style preset image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The style preset image could not be found"},
},
status_code=200,
)
async def get_style_preset_image(
style_preset_id: str = Path(description="The id of the style preset image to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
response = FileResponse(
path,
media_type="image/png",
filename=style_preset_id + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@style_presets_router.get(
"/export",
operation_id="export_style_presets",
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
status_code=200,
)
async def export_style_presets():
# Create an in-memory stream to store the CSV data
output = io.StringIO()
writer = csv.writer(output)
# Write the header
writer.writerow(["name", "prompt", "negative_prompt"])
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
for preset in style_presets:
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
csv_data = output.getvalue()
output.close()
return Response(
content=csv_data,
media_type="text/csv",
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
)
@style_presets_router.post(
"/import",
operation_id="import_style_presets",
)
async def import_style_presets(file: UploadFile = File(description="The file to import")):
try:
style_presets = await parse_presets_from_file(file)
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
except InvalidPresetImportDataError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=400, detail=str(e))
except UnsupportedFileTypeError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail=str(e))

View File

@@ -30,6 +30,7 @@ from invokeai.app.api.routers import (
images,
model_manager,
session_queue,
style_presets,
utilities,
workflows,
)
@@ -55,11 +56,13 @@ mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
@@ -106,6 +109,7 @@ app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.openapi = get_openapi_func(app)
@@ -184,8 +188,6 @@ def invoke_api() -> None:
check_cudnn(logger)
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config(
app=app,
host=app_config.host,

View File

@@ -21,6 +21,8 @@ from controlnet_aux import (
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, field_validator, model_validator
from transformers import pipeline
from transformers.pipelines import DepthEstimationPipeline
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -44,13 +46,12 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.util.devices import TorchDevice
class ControlField(BaseModel):
@@ -592,7 +593,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
return color_map
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
DEPTH_ANYTHING_MODELS = {
"large": "LiheYoung/depth-anything-large-hf",
"base": "LiheYoung/depth-anything-base-hf",
"small": "LiheYoung/depth-anything-small-hf",
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
}
@invocation(
@@ -600,28 +608,33 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.1.2",
version="1.1.3",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
default="small", description="The size of the depth model to use"
default="small_v2", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def loader(model_path: Path):
return DepthAnythingDetector.load_model(
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
)
def load_depth_anything(model_path: Path):
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
return DepthAnythingPipeline(depth_anything_pipeline)
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
) as model:
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
# Resizing to user target specified size
new_height = int(image.size[1] * (self.resolution / image.size[0]))
depth_map = depth_map.resize((self.resolution, new_height))
return depth_map
@invocation(

View File

@@ -1,278 +0,0 @@
from pathlib import Path
from typing import Literal
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers.models.auto import AutoModelForTextEncoding
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
from invokeai.backend.util.devices import TorchDevice
TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
base_class = FluxTransformer2DModel
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
auto_class = AutoModelForTextEncoding
@invocation(
"flux_text_to_image",
title="FLUX Text to Image",
tags=["image"],
category="image",
version="1.0.0",
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
use_8bit: bool = InputField(
default=False, description="Whether to quantize the transformer model to 8-bit precision."
)
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(default=4, description="Number of diffusion steps.")
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
image = self._run_vae_decoding(context, model_path, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
# Determine the T5 max sequence length based on the model.
if self.model == "flux-schnell":
max_seq_len = 256
# elif self.model == "flux-dev":
# max_seq_len = 512
else:
raise ValueError(f"Unknown model: {self.model}")
# Load the CLIP tokenizer.
clip_tokenizer_path = flux_model_dir / "tokenizer"
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
assert isinstance(clip_tokenizer, CLIPTokenizer)
# Load the T5 tokenizer.
t5_tokenizer_path = flux_model_dir / "tokenizer_2"
t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
assert isinstance(t5_tokenizer, T5TokenizerFast)
clip_text_encoder_path = flux_model_dir / "text_encoder"
t5_text_encoder_path = flux_model_dir / "text_encoder_2"
with (
context.models.load_local_model(
model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
) as clip_text_encoder,
context.models.load_local_model(
model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
) as t5_text_encoder,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(t5_text_encoder, T5EncoderModel)
pipeline = FluxPipeline(
scheduler=None,
vae=None,
text_encoder=clip_text_encoder,
tokenizer=clip_tokenizer,
text_encoder_2=t5_text_encoder,
tokenizer_2=t5_tokenizer,
transformer=None,
)
# prompt_embeds: T5 embeddings
# pooled_prompt_embeds: CLIP embeddings
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
prompt=self.positive_prompt,
prompt_2=self.positive_prompt,
device=TorchDevice.choose_torch_device(),
max_sequence_length=max_seq_len,
)
assert isinstance(prompt_embeds, torch.Tensor)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return prompt_embeds, pooled_prompt_embeds
def _run_diffusion(
self,
context: InvocationContext,
flux_model_dir: Path,
clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor,
):
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
transformer_path = flux_model_dir / "transformer"
with context.models.load_local_model(
model_path=transformer_path, loader=self._load_flux_transformer
) as transformer:
assert isinstance(transformer, FluxTransformer2DModel)
flux_pipeline_with_transformer = FluxPipeline(
scheduler=scheduler,
vae=None,
text_encoder=None,
tokenizer=None,
text_encoder_2=None,
tokenizer_2=None,
transformer=transformer,
)
t5_embeddings = t5_embeddings.to(dtype=transformer.dtype)
clip_embeddings = clip_embeddings.to(dtype=transformer.dtype)
latents = flux_pipeline_with_transformer(
height=self.height,
width=self.width,
num_inference_steps=self.num_steps,
guidance_scale=self.guidance,
generator=torch.Generator().manual_seed(self.seed),
prompt_embeds=t5_embeddings,
pooled_prompt_embeds=clip_embeddings,
output_type="latent",
return_dict=False,
)[0]
assert isinstance(latents, torch.Tensor)
return latents
def _run_vae_decoding(
self,
context: InvocationContext,
flux_model_dir: Path,
latents: torch.Tensor,
) -> Image.Image:
vae_path = flux_model_dir / "vae"
with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
assert isinstance(vae, AutoencoderKL)
flux_pipeline_with_vae = FluxPipeline(
scheduler=None,
vae=vae,
text_encoder=None,
tokenizer=None,
text_encoder_2=None,
tokenizer_2=None,
transformer=None,
)
latents = flux_pipeline_with_vae._unpack_latents(
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
)
latents = (
latents / flux_pipeline_with_vae.vae.config.scaling_factor
) + flux_pipeline_with_vae.vae.config.shift_factor
latents = latents.to(dtype=vae.dtype)
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
assert isinstance(image, Image.Image)
return image
@staticmethod
def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
model = CLIPTextModel.from_pretrained(path, local_files_only=True)
assert isinstance(model, CLIPTextModel)
return model
def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
if self.use_8bit:
model_8bit_path = path / "quantized"
if model_8bit_path.exists():
# The quantized model exists, load it.
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
# something that we should be able to make much faster.
q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
# Access the underlying wrapped model.
# We access the wrapped model, even though it is private, because it simplifies the type checking by
# always returning a T5EncoderModel from this function.
model = q_model._wrapped
else:
# The quantized model does not exist yet, quantize and save it.
# TODO(ryand): dtype?
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
assert isinstance(model, T5EncoderModel)
q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
model_8bit_path.mkdir(parents=True, exist_ok=True)
q_model.save_pretrained(model_8bit_path)
# (See earlier comment about accessing the wrapped model.)
model = q_model._wrapped
else:
model = T5EncoderModel.from_pretrained(path, local_files_only=True)
assert isinstance(model, T5EncoderModel)
return model
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
if self.use_8bit:
model_8bit_path = path / "quantized"
if model_8bit_path.exists():
# The quantized model exists, load it.
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
# something that we should be able to make much faster.
q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
# Access the underlying wrapped model.
# We access the wrapped model, even though it is private, because it simplifies the type checking by
# always returning a FluxTransformer2DModel from this function.
model = q_model._wrapped
else:
# The quantized model does not exist yet, quantize and save it.
# TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
# GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
# here.
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
assert isinstance(model, FluxTransformer2DModel)
q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
model_8bit_path.mkdir(parents=True, exist_ok=True)
q_model.save_pretrained(model_8bit_path)
# (See earlier comment about accessing the wrapped model.)
model = q_model._wrapped
else:
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
assert isinstance(model, FluxTransformer2DModel)
return model
@staticmethod
def _load_flux_vae(path: Path) -> AutoencoderKL:
model = AutoencoderKL.from_pretrained(path, local_files_only=True)
assert isinstance(model, AutoencoderKL)
return model

View File

@@ -91,6 +91,7 @@ class InvokeAIAppConfig(BaseSettings):
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
custom_nodes_dir: Path to directory for custom nodes.
style_presets_dir: Path to directory for style presets.
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
@@ -153,6 +154,7 @@ class InvokeAIAppConfig(BaseSettings):
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
# LOGGING
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
@@ -300,6 +302,11 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the models directory, resolved to an absolute path.."""
return self._resolve(self.models_dir)
@property
def style_presets_path(self) -> Path:
"""Path to the style presets directory, resolved to an absolute path.."""
return self._resolve(self.style_presets_dir)
@property
def convert_cache_path(self) -> Path:
"""Path to the converted cache models directory, resolved to an absolute path.."""

View File

@@ -1,46 +1,44 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
EventBase,
)
from invokeai.app.services.events.events_common import EventBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]()
self._queue = asyncio.Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
self._loop = loop
# We need to store a reference to the task so it doesn't get GC'd
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
self._background_tasks: set[asyncio.Task[None]] = set()
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.remove)
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
event = await self._queue.get()
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error
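A standalone sketch of the dispatch pattern introduced here: worker threads hand events to an asyncio.Queue owned by the loop via call_soon_threadsafe, and a long-lived task drains it. The names below are illustrative, not the InvokeAI classes:
import asyncio
import threading

async def drain(queue: asyncio.Queue) -> None:
    while True:
        event = await queue.get()
        if event is None:  # sentinel, mirroring stop() above
            break
        print(f"dispatching {event}")

async def main() -> None:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    task = loop.create_task(drain(queue))  # keep a reference so the task is not GC'd

    def worker() -> None:
        # Runs on a plain thread: enqueueing must be scheduled onto the loop's thread.
        loop.call_soon_threadsafe(queue.put_nowait, "invocation_started")
        loop.call_soon_threadsafe(queue.put_nowait, None)

    threading.Thread(target=worker).start()
    await task

asyncio.run(main())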

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
if TYPE_CHECKING:
from logging import Logger
@@ -61,6 +63,8 @@ class InvocationServices:
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]",
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
@@ -85,3 +89,5 @@ class InvocationServices:
self.workflow_records = workflow_records
self.tensors = tensors
self.conditioning = conditioning
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files

View File

@@ -16,6 +16,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -49,6 +50,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,61 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration14Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_style_presets(cursor)
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
"""Create the table used to store style presets."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS style_presets (
id TEXT NOT NULL PRIMARY KEY,
name TEXT NOT NULL,
preset_data TEXT NOT NULL,
type TEXT NOT NULL DEFAULT "user",
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS style_presets
AFTER UPDATE
ON style_presets FOR EACH ROW
BEGIN
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def build_migration_14() -> Migration:
"""
Build the migration from database version 13 to 14.
This migration does the following:
- Create the table used to store style presets.
"""
migration_14 = Migration(
from_version=13,
to_version=14,
callback=Migration14Callback(),
)
return migration_14
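A quick sketch exercising the new migration callback against an in-memory database, to confirm the table is created and the updated_at trigger fires; the inserted row is illustrative:
import sqlite3
import time

from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import Migration14Callback

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
Migration14Callback()(cursor)  # creates style_presets plus its index and trigger

cursor.execute(
    "INSERT INTO style_presets (id, name, preset_data) VALUES (?, ?, ?);",
    ("abc", "My Preset", '{"positive_prompt": "{prompt}", "negative_prompt": ""}'),
)
time.sleep(0.01)  # make the update timestamp visibly later
cursor.execute("UPDATE style_presets SET name = ? WHERE id = ?;", ("Renamed", "abc"))
created, updated = cursor.execute("SELECT created_at, updated_at FROM style_presets WHERE id = 'abc';").fetchone()
assert updated >= created  # the trigger bumped updated_at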

Binary files not shown: 18 default style preset images added (46 KiB – 160 KiB each).

View File

@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class StylePresetImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, style_preset_id: str) -> PILImageType:
"""Retrieves a style preset image as PIL Image."""
pass
@abstractmethod
def get_path(self, style_preset_id: str) -> Path:
"""Gets the internal path to a style preset image."""
pass
@abstractmethod
def get_url(self, style_preset_id: str) -> str | None:
"""Gets the URL to fetch a style preset image."""
pass
@abstractmethod
def save(self, style_preset_id: str, image: PILImageType) -> None:
"""Saves a style preset image."""
pass
@abstractmethod
def delete(self, style_preset_id: str) -> None:
"""Deletes a style preset image."""
pass

View File

@@ -0,0 +1,19 @@
class StylePresetImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message: str = "Style preset image file not found"):
super().__init__(message)
class StylePresetImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message: str = "Style preset image file not saved"):
super().__init__(message)
class StylePresetImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message: str = "Style preset image file not deleted"):
super().__init__(message)

View File

@@ -0,0 +1,88 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_images.style_preset_images_common import (
StylePresetImageFileDeleteException,
StylePresetImageFileNotFoundException,
StylePresetImageFileSaveException,
)
from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.thumbnails import make_thumbnail
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
"""Stores images on disk"""
def __init__(self, style_preset_images_folder: Path):
self._style_preset_images_folder = style_preset_images_folder
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, style_preset_id: str) -> PILImageType:
try:
path = self.get_path(style_preset_id)
return Image.open(path)
except FileNotFoundError as e:
raise StylePresetImageFileNotFoundException from e
def save(self, style_preset_id: str, image: PILImageType) -> None:
try:
self._validate_storage_folders()
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise StylePresetImageFileSaveException from e
def get_path(self, style_preset_id: str) -> Path:
style_preset = self._invoker.services.style_preset_records.get(style_preset_id)
if style_preset.type is PresetType.Default:
default_images_dir = Path(__file__).parent / Path("default_style_preset_images")
path = default_images_dir / (style_preset.name + ".png")
else:
path = self._style_preset_images_folder / (style_preset_id + ".webp")
return path
def get_url(self, style_preset_id: str) -> str | None:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
return
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
# The image URL never changes, so we must add a random query string to it to prevent caching
url += f"?{uuid_string()}"
return url
def delete(self, style_preset_id: str) -> None:
try:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
raise StylePresetImageFileNotFoundException
path.unlink()
except StylePresetImageFileNotFoundException as e:
raise StylePresetImageFileNotFoundException from e
except Exception as e:
raise StylePresetImageFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)

View File

@@ -0,0 +1,146 @@
[
{
"name": "Photography (General)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Studio Lighting)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}, photography. f/8 photo. centered subject, studio lighting.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Landscape)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}, landscape photograph, f/12, lifelike, highly detailed.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Portrait)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. photography. portraiture. catch light in eyes. one flash. rembrandt lighting. Soft box. dark shadows. High contrast. 80mm lens. F2.8.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Black and White)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} photography. natural light. 80mm lens. F1.4. strong contrast, hard light. dark contrast. blurred background. black and white",
"negative_prompt": "painting, digital art. sketch, colour+"
}
},
{
"name": "Architectural Visualization",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. architectural photography, f/12, luxury, aesthetically pleasing form and function.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Concept Art (Fantasy)",
"type": "default",
"preset_data": {
"positive_prompt": "concept artwork of a {prompt}. (digital painterly art style)++, mythological, (textured 2d dry media brushpack)++, glazed brushstrokes, otherworldly. painting+, illustration+",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Sci-Fi)",
"type": "default",
"preset_data": {
"positive_prompt": "(concept art)++, {prompt}, (sleek futurism)++, (textured 2d dry media)++, metallic highlights, digital painting style",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Character)",
"type": "default",
"preset_data": {
"positive_prompt": "(character concept art)++, stylized painterly digital painting of {prompt}, (painterly, impasto. Dry brush.)++",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Painterly)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} oil painting. high contrast. impasto. sfumato. chiaroscuro. Palette knife.",
"negative_prompt": "photo. smooth. border. frame"
}
},
{
"name": "Environment Art",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} environment artwork, hyper-realistic digital painting style with cinematic composition, atmospheric, depth and detail, voluminous. textured dry brush 2d media",
"negative_prompt": "photo, distorted, blurry, out of focus. sketch."
}
},
{
"name": "Interior Design (Visualization)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} interior design photo, gentle shadows, light mid-tones, dimension, mix of smooth and textured surfaces, focus on negative space and clean lines, focus",
"negative_prompt": "photo, distorted. sketch."
}
},
{
"name": "Product Rendering",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} high quality product photography, 3d rendering with key lighting, shallow depth of field, simple plain background, studio lighting.",
"negative_prompt": "blurry, sketch, messy, dirty. unfinished."
}
},
{
"name": "Sketch",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} black and white pencil drawing, off-center composition, cross-hatching for shadows, bold strokes, textured paper. sketch+++",
"negative_prompt": "blurry, photo, painting, color. messy, dirty. unfinished. frame, borders."
}
},
{
"name": "Line Art",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} Line art. bold outline. simplistic. white background. 2d",
"negative_prompt": "photo. digital art. greyscale. solid black. painting"
}
},
{
"name": "Anime",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} anime++, bold outline, cel-shaded coloring, shounen, seinen",
"negative_prompt": "(photo)+++. greyscale. solid black. painting"
}
},
{
"name": "Illustration",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} illustration, bold linework, illustrative details, vector art style, flat coloring",
"negative_prompt": "(photo)+++. greyscale. painting, black and white."
}
},
{
"name": "Vehicles",
"type": "default",
"preset_data": {
"positive_prompt": "A weird futuristic normal auto, {prompt} elegant design, nice color, nice wheels",
"negative_prompt": "sketch. digital art. greyscale. painting"
}
}
]
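The presets above rely on a {prompt} placeholder that the UI substitutes with the user's prompt; a minimal illustration of that substitution (the snippet below is hypothetical, not InvokeAI code):
preset_positive = "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism"
user_prompt = "a red fox in tall grass"

expanded = preset_positive.replace("{prompt}", user_prompt)
print(expanded)
# a red fox in tall grass. photography. f/2.8 macro photo, bokeh, photorealism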

View File

@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetRecordDTO,
StylePresetWithoutId,
)
class StylePresetRecordsStorageBase(ABC):
"""Base class for style preset storage services."""
@abstractmethod
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Get style preset by id."""
pass
@abstractmethod
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
"""Creates a style preset."""
pass
@abstractmethod
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
"""Creates many style presets."""
pass
@abstractmethod
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
"""Updates a style preset."""
pass
@abstractmethod
def delete(self, style_preset_id: str) -> None:
"""Deletes a style preset."""
pass
@abstractmethod
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
"""Gets many workflows."""
pass

View File

@@ -0,0 +1,139 @@
import codecs
import csv
import json
from enum import Enum
from typing import Any, Optional
import pydantic
from fastapi import UploadFile
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
from invokeai.app.util.metaenum import MetaEnum
class StylePresetNotFoundError(Exception):
"""Raised when a style preset is not found"""
class PresetData(BaseModel, extra="forbid"):
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
PresetDataValidator = TypeAdapter(PresetData)
class PresetType(str, Enum, metaclass=MetaEnum):
User = "user"
Default = "default"
Project = "project"
class StylePresetChanges(BaseModel, extra="forbid"):
name: Optional[str] = Field(default=None, description="The style preset's new name.")
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
type: Optional[PresetType] = Field(description="The updated type of the style preset")
class StylePresetWithoutId(BaseModel):
name: str = Field(description="The name of the style preset.")
preset_data: PresetData = Field(description="The preset data")
type: PresetType = Field(description="The type of style preset")
class StylePresetRecordDTO(StylePresetWithoutId):
id: str = Field(description="The style preset ID.")
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
return StylePresetRecordDTOValidator.validate_python(data)
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
class StylePresetRecordWithImage(StylePresetRecordDTO):
image: Optional[str] = Field(description="The path for image")
class StylePresetImportRow(BaseModel):
name: str = Field(min_length=1, description="The name of the preset.")
positive_prompt: str = Field(
default="",
description="The positive prompt for the preset.",
validation_alias=AliasChoices("positive_prompt", "prompt"),
)
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
StylePresetImportList = list[StylePresetImportRow]
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
class UnsupportedFileTypeError(ValueError):
"""Raised when an unsupported file type is encountered"""
pass
class InvalidPresetImportDataError(ValueError):
"""Raised when invalid preset import data is encountered"""
pass
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
"""Parses style presets from a file. The file must be a CSV or JSON file.
If CSV, the file must have the following columns:
- name
- prompt (or positive_prompt)
- negative_prompt
If JSON, the file must be a list of objects with the following keys:
- name
- prompt (or positive_prompt)
- negative_prompt
Args:
file (UploadFile): The file to parse.
Returns:
list[StylePresetWithoutId]: The parsed style presets.
Raises:
UnsupportedFileTypeError: If the file type is not supported.
InvalidPresetImportDataError: If the data in the file is invalid.
"""
if file.content_type not in ["text/csv", "application/json"]:
raise UnsupportedFileTypeError()
if file.content_type == "text/csv":
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
data = list(csv_reader)
else: # file.content_type == "application/json":
json_data = await file.read()
data = json.loads(json_data)
try:
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
style_presets: list[StylePresetWithoutId] = []
for imported in imported_presets:
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
except pydantic.ValidationError as e:
if file.content_type == "text/csv":
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
else: # file.content_type == "application/json":
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
raise InvalidPresetImportDataError(msg) from e
finally:
file.file.close()
return style_presets
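A minimal usage sketch for the parser above with an in-memory CSV upload; the UploadFile construction follows recent Starlette versions, and the preset values are illustrative:
import asyncio
import io

from fastapi import UploadFile
from starlette.datastructures import Headers

from invokeai.app.services.style_preset_records.style_preset_records_common import parse_presets_from_file

csv_bytes = b"name,prompt,negative_prompt\nMoody,{prompt} dramatic lighting,blurry\n"
upload = UploadFile(
    file=io.BytesIO(csv_bytes),
    filename="presets.csv",
    headers=Headers({"content-type": "text/csv"}),
)

presets = asyncio.run(parse_presets_from_file(upload))
print(presets[0].name, presets[0].preset_data.positive_prompt)
# Moody {prompt} dramatic lighting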

View File

@@ -0,0 +1,215 @@
import json
from pathlib import Path
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordDTO,
StylePresetWithoutId,
)
from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._sync_default_style_presets()
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = self._cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
self._lock.acquire()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
self._lock.acquire()
# Change the name of a style preset
if changes.name is not None:
self._cursor.execute(
"""--sql
UPDATE style_presets
SET name = ?
WHERE id = ?;
""",
(changes.name, style_preset_id),
)
# Change the preset data for a style preset
if changes.preset_data is not None:
self._cursor.execute(
"""--sql
UPDATE style_presets
SET preset_data = ?
WHERE id = ?;
""",
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
try:
self._lock.acquire()
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
if type is not None:
self._cursor.execute(main_query, (type,))
else:
self._cursor.execute(main_query)
rows = self._cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
# Next, parse and create the default style presets
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
for preset in presets:
style_preset = StylePresetWithoutId.model_validate(preset)
self.create(style_preset)
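Because PresetType subclasses str, its members can be bound directly as SQLite parameters, which is what create() and get_many() above rely on; a tiny standalone check (table and rows are illustrative):
import sqlite3

from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE style_presets (id TEXT, name TEXT, type TEXT);")
conn.execute("INSERT INTO style_presets VALUES ('1', 'Anime', 'default');")
rows = conn.execute("SELECT name FROM style_presets WHERE type = ?;", (PresetType.Default,)).fetchall()
print(rows)  # [('Anime',)]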

View File

@@ -13,3 +13,8 @@ class UrlServiceBase(ABC):
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass
@abstractmethod
def get_style_preset_image_url(self, style_preset_id: str) -> str:
"""Gets the URL for a style preset image"""
pass

View File

@@ -19,3 +19,6 @@ class LocalUrlService(UrlServiceBase):
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url_v2}/models/i/{model_key}/image"
def get_style_preset_image_url(self, style_preset_id: str) -> str:
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"

View File

@@ -81,7 +81,7 @@ def get_openapi_func(
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"properties": dict(sorted(invocation_output_map_properties.items())),
"required": invocation_output_map_required,
}

View File

@@ -1,90 +0,0 @@
from pathlib import Path
from typing import Literal
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = {
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
}
transform = Compose(
[
Resize(
width=518,
height=518,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
]
)
class DepthAnythingDetector:
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
self.model = model
self.device = device
@staticmethod
def load_model(
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
) -> DPT_DINOv2:
match model_size:
case "small":
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base":
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
model.eval()
model.to(device)
return model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
if not self.model:
logger.warn("DepthAnything model was not loaded. Returning original image")
return image
np_image = np.array(image, dtype=np.uint8)
np_image = np_image[:, :, ::-1] / 255.0
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": np_image})["image"]
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
with torch.no_grad():
depth = self.model(tensor_image)
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
depth_map = Image.fromarray(depth_map)
new_height = int(image_height * (resolution / image_width))
depth_map = depth_map.resize((resolution, new_height))
return depth_map

View File

@@ -0,0 +1,31 @@
from typing import Optional
import torch
from PIL import Image
from transformers.pipelines import DepthEstimationPipeline
from invokeai.backend.raw_model import RawModel
class DepthAnythingPipeline(RawModel):
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
for Invoke's Model Management System"""
def __init__(self, pipeline: DepthEstimationPipeline) -> None:
self._pipeline = pipeline
def generate_depth(self, image: Image.Image) -> Image.Image:
depth_map = self._pipeline(image)["depth"]
assert isinstance(depth_map, Image.Image)
return depth_map
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._pipeline.model.to(device=device, dtype=dtype)
self._pipeline.device = self._pipeline.model.device
def calc_size(self) -> int:
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._pipeline.model)
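A hedged sketch of wrapping a transformers depth-estimation pipeline in the adapter above; the model id and file names are examples, and any depth-estimation checkpoint should work:
from PIL import Image
from transformers import pipeline
from transformers.pipelines import DepthEstimationPipeline

from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline

hf_pipeline = pipeline("depth-estimation", model="LiheYoung/depth-anything-small-hf")
assert isinstance(hf_pipeline, DepthEstimationPipeline)

wrapped = DepthAnythingPipeline(hf_pipeline)
depth = wrapped.generate_depth(Image.open("input.png"))
depth.save("depth.png")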

View File

@@ -1,145 +0,0 @@
import torch.nn as nn
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = 1
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
if self.bn:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups = 1
self.expand = expand
out_features = features
if self.expand:
out_features = features // 2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
self.size = size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
output = self.out_conv(output)
return output

View File

@@ -1,183 +0,0 @@
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock, _make_scratch
torchhub_path = Path(__file__).parent.parent / "torchhub"
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class DPTHead(nn.Module):
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
super(DPTHead, self).__init__()
self.nclass = nclass
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList(
[
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
)
for out_channel in out_channels
]
)
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
),
]
)
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
head_features_1 = features
head_features_2 = 32
if nclass > 1:
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
)
else:
self.scratch.output_conv1 = nn.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
class DPT_DINOv2(nn.Module):
def __init__(
self,
features,
out_channels,
encoder="vitl",
use_bn=False,
use_clstoken=False,
):
super(DPT_DINOv2, self).__init__()
assert encoder in ["vits", "vitb", "vitl"]
# # in case the Internet connection is not stable, please load the DINOv2 locally
# if use_local:
# self.pretrained = torch.hub.load(
# torchhub_path / "facebookresearch_dinov2_main",
# "dinov2_{:}14".format(encoder),
# source="local",
# pretrained=False,
# )
# else:
# self.pretrained = torch.hub.load(
# "facebookresearch/dinov2",
# "dinov2_{:}14".format(encoder),
# )
self.pretrained = torch.hub.load(
"facebookresearch/dinov2",
"dinov2_{:}14".format(encoder),
)
dim = self.pretrained.blocks[0].attn.qkv.in_features
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
def forward(self, x):
h, w = x.shape[-2:]
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
patch_h, patch_w = h // 14, w // 14
depth = self.depth_head(features, patch_h, patch_w)
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
depth = F.relu(depth)
return depth.squeeze(1)

View File

@@ -1,227 +0,0 @@
import math
import cv2
import numpy as np
import torch
import torch.nn.functional as F
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height)."""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as least as possbile
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
if "semseg_mask" in sample:
# sample["semseg_mask"] = cv2.resize(
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
# )
sample["semseg_mask"] = F.interpolate(
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
).numpy()[0, 0]
if "mask" in sample:
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
# sample["mask"] = sample["mask"].astype(bool)
# print(sample['image'].shape, sample['depth'].shape)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std."""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input."""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
if "semseg_mask" in sample:
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
return sample

View File

@@ -1,129 +0,0 @@
import json
import os
import time
from pathlib import Path
from typing import Union
import torch
from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import (
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_WEIGHTS_NAME,
_get_checkpoint_shard_files,
is_accelerate_available,
)
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from invokeai.backend.requantize import requantize
class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
base_class = FluxTransformer2DModel
@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
if cls.base_class is None:
raise ValueError("The `base_class` attribute needs to be configured.")
if not is_accelerate_available():
raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
from accelerate import init_empty_weights
if os.path.isdir(model_name_or_path):
# Look for a quantization map
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
if not os.path.exists(qmap_path):
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
# Look for original model config file.
model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
if not os.path.exists(model_config_path):
raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
with open(qmap_path, "r", encoding="utf-8") as f:
qmap = json.load(f)
with open(model_config_path, "r", encoding="utf-8") as f:
original_model_cls_name = json.load(f)["_class_name"]
configured_cls_name = cls.base_class.__name__
if configured_cls_name != original_model_cls_name:
raise ValueError(
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
)
# Create an empty model
config = cls.base_class.load_config(model_name_or_path)
with init_empty_weights():
model = cls.base_class.from_config(config)
# Look for the index of a sharded checkpoint
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
if os.path.exists(checkpoint_file):
# Convert the checkpoint path to a list of shards
_, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
# Create a mapping for the sharded safetensor files
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
else:
# Look for a single checkpoint file
checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
if not os.path.exists(checkpoint_file):
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
# Get state_dict from model checkpoint
state_dict = load_state_dict(checkpoint_file)
# Requantize and load quantized weights from state_dict
requantize(model, state_dict=state_dict, quantization_map=qmap)
model.eval()
return cls(model)
else:
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
# model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
model_8bit_path = path / "quantized"
if model_8bit_path.exists():
# The quantized model exists, load it.
# TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
# something that we should be able to make much faster.
q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
# Access the underlying wrapped model.
# We access the wrapped model, even though it is private, because it simplifies the type checking by
# always returning a FluxTransformer2DModel from this function.
model = q_model._wrapped
else:
# The quantized model does not exist yet, quantize and save it.
# TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
# GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
# here.
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
assert isinstance(model, FluxTransformer2DModel)
q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
model_8bit_path.mkdir(parents=True, exist_ok=True)
q_model.save_pretrained(model_8bit_path)
# (See earlier comment about accessing the wrapped model.)
model = q_model._wrapped
assert isinstance(model, FluxTransformer2DModel)
return model
def main():
start = time.time()
model = load_flux_transformer(
Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
)
print(f"Time to load: {time.time() - start}s")
print("hi")
if __name__ == "__main__":
main()

View File

@@ -220,11 +220,17 @@ class LoKRLayer(LoRALayerBase):
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
else:
self.w1_b = None
self.w1_a = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
else:
self.w2_a = None
self.w2_b = None
self.t2 = values.get("lokr_t2", None)
@@ -372,7 +378,39 @@ class IA3Layer(LoRALayerBase):
self.on_input = self.on_input.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
class NormLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["w_norm"]
self.bias = values.get("b_norm", None)
self.rank = None # unscaled
self.check_keys(values, {"w_norm", "b_norm"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
class LoRAModelRaw(RawModel): # (torch.nn.Module):
@@ -513,6 +551,10 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")

View File

@@ -11,6 +11,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
@@ -45,6 +46,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
SpandrelImageToImageModel,
GroundingDinoPipeline,
SegmentAnythingPipeline,
DepthAnythingPipeline,
),
):
return model.calc_size()

View File

@@ -54,7 +54,6 @@ def filter_files(
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
)
):
paths.append(file)
@@ -63,7 +62,7 @@ def filter_files(
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
paths.append(file)
# limit search to subfolder if requested
@@ -98,9 +97,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
if variant == ModelRepoVariant.Flax:
result.add(path)
# Note: '.model' was added to support:
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
elif path.suffix in [".json", ".txt", ".model"]:
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif variant in [
@@ -143,23 +140,6 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
continue
for candidate_list in subfolder_weights.values():
# Check if at least one of the files has the explicit fp16 variant.
at_least_one_fp16 = False
for candidate in candidate_list:
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
at_least_one_fp16 = True
break
if not at_least_one_fp16:
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
# we'll simply keep all the candidates. An example of a model that hits this case is
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
for candidate in candidate_list:
result.add(candidate.path)
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
# candidate.
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)
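A quick check of the tightened filename pattern above: it still accepts "model.safetensors" and a single variant suffix such as "model.fp16.safetensors", but no longer matches arbitrary "model*" names; the filenames are illustrative:
import re

pattern = re.compile(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$")

for name in ["model.safetensors", "model.fp16.safetensors", "diffusion_pytorch_model.bin", "model_renamed_final.ckpt"]:
    print(name, bool(pattern.search(name)))
# model.safetensors True
# model.fp16.safetensors True
# diffusion_pytorch_model.bin True
# model_renamed_final.ckpt False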

View File

@@ -1,77 +0,0 @@
import json
import os
from typing import Union
from diffusers.models.model_loading_utils import load_state_dict
from diffusers.utils import (
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_WEIGHTS_NAME,
_get_checkpoint_shard_files,
is_accelerate_available,
)
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from invokeai.backend.requantize import requantize
class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
if cls.base_class is None:
raise ValueError("The `base_class` attribute needs to be configured.")
if not is_accelerate_available():
raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
from accelerate import init_empty_weights
if os.path.isdir(model_name_or_path):
# Look for a quantization map
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
if not os.path.exists(qmap_path):
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
# Look for original model config file.
model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
if not os.path.exists(model_config_path):
raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
with open(qmap_path, "r", encoding="utf-8") as f:
qmap = json.load(f)
with open(model_config_path, "r", encoding="utf-8") as f:
original_model_cls_name = json.load(f)["_class_name"]
configured_cls_name = cls.base_class.__name__
if configured_cls_name != original_model_cls_name:
raise ValueError(
f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
)
# Create an empty model
config = cls.base_class.load_config(model_name_or_path)
with init_empty_weights():
model = cls.base_class.from_config(config)
# Look for the index of a sharded checkpoint
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
if os.path.exists(checkpoint_file):
# Convert the checkpoint path to a list of shards
_, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
# Create a mapping for the sharded safetensor files
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
else:
# Look for a single checkpoint file
checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
if not os.path.exists(checkpoint_file):
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
# Get state_dict from model checkpoint
state_dict = load_state_dict(checkpoint_file)
# Requantize and load quantized weights from state_dict
requantize(model, state_dict=state_dict, quantization_map=qmap)
model.eval()
return cls(model)
else:
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")

View File

@@ -1,61 +0,0 @@
import json
import os
from typing import Union
from optimum.quanto.models import QuantizedTransformersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from transformers import AutoConfig
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
from invokeai.backend.requantize import requantize
class FastQuantizedTransformersModel(QuantizedTransformersModel):
@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
if cls.auto_class is None:
raise ValueError(
"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
)
if not is_accelerate_available():
raise ValueError("Reloading a quantized transformers model requires the accelerate library.")
from accelerate import init_empty_weights
if os.path.isdir(model_name_or_path):
# Look for a quantization map
qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
if not os.path.exists(qmap_path):
raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
with open(qmap_path, "r", encoding="utf-8") as f:
qmap = json.load(f)
# Create an empty model
config = AutoConfig.from_pretrained(model_name_or_path)
with init_empty_weights():
model = cls.auto_class.from_config(config)
# Look for the index of a sharded checkpoint
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
if os.path.exists(checkpoint_file):
# Convert the checkpoint path to a list of shards
checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
# Create a mapping for the sharded safetensor files
state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
else:
# Look for a single checkpoint file
checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)
if not os.path.exists(checkpoint_file):
raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
# Get state_dict from model checkpoint
state_dict = load_state_dict(checkpoint_file)
# Requantize and load quantized weights from state_dict
requantize(model, state_dict=state_dict, quantization_map=qmap)
if getattr(model.config, "tie_word_embeddings", True):
# Tie output weight embeddings to input weight embeddings
# Note that if they were quantized they would NOT be tied
model.tie_weights()
# Set model in evaluation mode as it is done in transformers
model.eval()
return cls(model)
else:
raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")

View File

@@ -1,53 +0,0 @@
from typing import Any, Dict
import torch
from optimum.quanto.quantize import _quantize_submodule
# def custom_freeze(model: torch.nn.Module):
# for name, m in model.named_modules():
# if isinstance(m, QModuleMixin):
# m.weight =
# m.freeze()
def requantize(
model: torch.nn.Module,
state_dict: Dict[str, Any],
quantization_map: Dict[str, Dict[str, str]],
device: torch.device = None,
):
if device is None:
device = next(model.parameters()).device
if device.type == "meta":
device = torch.device("cpu")
# Quantize the model with parameters from the quantization map
for name, m in model.named_modules():
qconfig = quantization_map.get(name, None)
if qconfig is not None:
weights = qconfig["weights"]
if weights == "none":
weights = None
activations = qconfig["activations"]
if activations == "none":
activations = None
_quantize_submodule(model, name, m, weights=weights, activations=activations)
# Move model parameters and buffers to CPU before materializing quantized weights
for name, m in model.named_modules():
def move_tensor(t, device):
if t.device.type == "meta":
return torch.empty_like(t, device=device)
return t.to(device)
for name, param in m.named_parameters(recurse=False):
setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
for name, param in m.named_buffers(recurse=False):
setattr(m, name, move_tensor(param, "cpu"))
# Freeze model and move to target device
# freeze(model)
# model.to(device)
# Load the quantized model weights
model.load_state_dict(state_dict, strict=False)

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import (
@@ -32,7 +32,7 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
"""
A ControlNet model.

View File

@@ -53,64 +53,63 @@
},
"dependencies": {
"@chakra-ui/react-use-size": "^2.1.0",
"@dagrejs/dagre": "^1.1.2",
"@dagrejs/graphlib": "^2.2.2",
"@dagrejs/dagre": "^1.1.3",
"@dagrejs/graphlib": "^2.2.3",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.0.18",
"@invoke-ai/ui-library": "^0.0.25",
"@nanostores/react": "^0.7.2",
"@fontsource-variable/inter": "^5.0.20",
"@invoke-ai/ui-library": "^0.0.29",
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.2.3",
"@roarr/browser-log-writer": "^1.3.0",
"chakra-react-select": "^4.7.6",
"compare-versions": "^6.1.0",
"chakra-react-select": "^4.9.1",
"compare-versions": "^6.1.1",
"dateformat": "^5.0.3",
"fracturedjsonjs": "^4.0.1",
"framer-motion": "^11.1.8",
"i18next": "^23.11.3",
"i18next-http-backend": "^2.5.1",
"fracturedjsonjs": "^4.0.2",
"framer-motion": "^11.3.24",
"i18next": "^23.12.2",
"i18next-http-backend": "^2.5.2",
"idb-keyval": "^6.2.1",
"jsondiffpatch": "^0.6.0",
"konva": "^9.3.6",
"konva": "^9.3.14",
"lodash-es": "^4.17.21",
"nanostores": "^0.10.3",
"nanostores": "^0.11.2",
"new-github-issue-url": "^1.0.0",
"overlayscrollbars": "^2.7.3",
"overlayscrollbars": "^2.10.0",
"overlayscrollbars-react": "^0.5.6",
"query-string": "^9.0.0",
"query-string": "^9.1.0",
"react": "^18.3.1",
"react-colorful": "^5.6.1",
"react-dom": "^18.3.1",
"react-dropzone": "^14.2.3",
"react-error-boundary": "^4.0.13",
"react-hook-form": "^7.51.4",
"react-hook-form": "^7.52.2",
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^14.1.1",
"react-icons": "^5.2.0",
"react-i18next": "^14.1.3",
"react-icons": "^5.2.1",
"react-konva": "^18.2.10",
"react-redux": "9.1.2",
"react-resizable-panels": "^2.0.19",
"react-resizable-panels": "^2.0.23",
"react-select": "5.8.0",
"react-use": "^17.5.0",
"react-virtuoso": "^4.7.10",
"reactflow": "^11.11.3",
"react-use": "^17.5.1",
"react-virtuoso": "^4.9.0",
"reactflow": "^11.11.4",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0",
"redux-undo": "^1.1.0",
"rfdc": "^1.3.1",
"rfdc": "^1.4.1",
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.5",
"use-debounce": "^10.0.0",
"use-debounce": "^10.0.2",
"use-device-pixel-ratio": "^1.1.2",
"use-image": "^1.1.1",
"uuid": "^9.0.1",
"zod": "^3.23.6",
"zod-validation-error": "^3.2.0"
"uuid": "^10.0.0",
"zod": "^3.23.8",
"zod-validation-error": "^3.3.1"
},
"peerDependencies": {
"@chakra-ui/react": "^2.8.2",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"ts-toolbelt": "^9.6.0"
@@ -118,38 +117,38 @@
"devDependencies": {
"@invoke-ai/eslint-config-react": "^0.0.14",
"@invoke-ai/prettier-config-react": "^0.0.7",
"@storybook/addon-essentials": "^8.0.10",
"@storybook/addon-interactions": "^8.0.10",
"@storybook/addon-links": "^8.0.10",
"@storybook/addon-storysource": "^8.0.10",
"@storybook/manager-api": "^8.0.10",
"@storybook/react": "^8.0.10",
"@storybook/react-vite": "^8.0.10",
"@storybook/theming": "^8.0.10",
"@storybook/addon-essentials": "^8.2.8",
"@storybook/addon-interactions": "^8.2.8",
"@storybook/addon-links": "^8.2.8",
"@storybook/addon-storysource": "^8.2.8",
"@storybook/manager-api": "^8.2.8",
"@storybook/react": "^8.2.8",
"@storybook/react-vite": "^8.2.8",
"@storybook/theming": "^8.2.8",
"@types/dateformat": "^5.0.2",
"@types/lodash-es": "^4.17.12",
"@types/node": "^20.12.10",
"@types/react": "^18.3.1",
"@types/node": "^20.14.15",
"@types/react": "^18.3.3",
"@types/react-dom": "^18.3.0",
"@types/uuid": "^9.0.8",
"@vitejs/plugin-react-swc": "^3.6.0",
"@types/uuid": "^10.0.0",
"@vitejs/plugin-react-swc": "^3.7.0",
"@vitest/coverage-v8": "^1.5.0",
"@vitest/ui": "^1.5.0",
"concurrently": "^8.2.2",
"dpdm": "^3.14.0",
"eslint": "^8.57.0",
"eslint-plugin-i18next": "^6.0.3",
"eslint-plugin-i18next": "^6.0.9",
"eslint-plugin-path": "^1.3.0",
"knip": "^5.12.3",
"knip": "^5.27.2",
"openapi-types": "^12.1.3",
"openapi-typescript": "^6.7.5",
"prettier": "^3.2.5",
"openapi-typescript": "^7.3.0",
"prettier": "^3.3.3",
"rollup-plugin-visualizer": "^5.12.0",
"storybook": "^8.0.10",
"storybook": "^8.2.8",
"ts-toolbelt": "^9.6.0",
"tsafe": "^1.6.6",
"typescript": "^5.4.5",
"vite": "^5.2.11",
"tsafe": "^1.7.2",
"typescript": "^5.5.4",
"vite": "^5.4.0",
"vite-plugin-css-injected-by-js": "^3.5.1",
"vite-plugin-dts": "^3.9.1",
"vite-plugin-eslint": "^1.8.1",

File diff suppressed because it is too large.

View File

@@ -200,6 +200,7 @@
"delete": "Delete",
"depthAnything": "Depth Anything",
"depthAnythingDescription": "Depth map generation using the Depth Anything technique",
"depthAnythingSmallV2": "Small V2",
"depthMidas": "Depth (Midas)",
"depthMidasDescription": "Depth map generation using Midas",
"depthZoe": "Depth (Zoe)",
@@ -1140,6 +1141,8 @@
"imageSavingFailed": "Image Saving Failed",
"imageUploaded": "Image Uploaded",
"imageUploadFailed": "Image Upload Failed",
"importFailed": "Import Failed",
"importSuccessful": "Import Successful",
"invalidUpload": "Invalid Upload",
"loadedWithWarnings": "Workflow Loaded with Warnings",
"maskSavedAssets": "Mask Saved to Assets",
@@ -1672,6 +1675,7 @@
"layers_other": "Layers"
},
"upscaling": {
"upscale": "Upscale",
"creativity": "Creativity",
"exceedsMaxSize": "Upscale settings exceed max size limit",
"exceedsMaxSizeDetails": "Max upscale limit is {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixels. Please try a smaller image or decrease your scale selection.",
@@ -1688,6 +1692,53 @@
"missingUpscaleModel": "Missing upscale model",
"missingTileControlNetModel": "No valid tile ControlNet models installed"
},
"stylePresets": {
"active": "Active",
"choosePromptTemplate": "Choose Prompt Template",
"clearTemplateSelection": "Clear Template Selection",
"copyTemplate": "Copy Template",
"createPromptTemplate": "Create Prompt Template",
"defaultTemplates": "Default Templates",
"deleteImage": "Delete Image",
"deleteTemplate": "Delete Template",
"deleteTemplate2": "Are you sure you want to delete this template? This cannot be undone.",
"exportPromptTemplates": "Export My Prompt Templates (CSV)",
"editTemplate": "Edit Template",
"exportDownloaded": "Export Downloaded",
"exportFailed": "Unable to generate and download CSV",
"flatten": "Flatten selected template into current prompt",
"importTemplates": "Import Prompt Templates (CSV/JSON)",
"acceptedColumnsKeys": "Accepted columns/keys:",
"nameColumn": "'name'",
"positivePromptColumn": "'prompt' or 'positive_prompt'",
"negativePromptColumn": "'negative_prompt'",
"insertPlaceholder": "Insert placeholder",
"myTemplates": "My Templates",
"name": "Name",
"negativePrompt": "Negative Prompt",
"noTemplates": "No templates",
"noMatchingTemplates": "No matching templates",
"promptTemplatesDesc1": "Prompt templates add text to the prompts you write in the prompt box.",
"promptTemplatesDesc2": "Use the placeholder string <Pre>{{placeholder}}</Pre> to specify where your prompt should be included in the template.",
"promptTemplatesDesc3": "If you omit the placeholder, the template will be appended to the end of your prompt.",
"positivePrompt": "Positive Prompt",
"preview": "Preview",
"private": "Private",
"promptTemplateCleared": "Prompt Template Cleared",
"searchByName": "Search by name",
"shared": "Shared",
"sharedTemplates": "Shared Templates",
"templateActions": "Template Actions",
"templateDeleted": "Prompt template deleted",
"toggleViewMode": "Toggle View Mode",
"type": "Type",
"unableToDeleteTemplate": "Unable to delete prompt template",
"updatePromptTemplate": "Update Prompt Template",
"uploadImage": "Upload Image",
"useForTemplate": "Use For Prompt Template",
"viewList": "View Template List",
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box."
},
"upsell": {
"inviteTeammates": "Invite Teammates",
"professional": "Professional",

View File

@@ -90,7 +90,7 @@
"disabled": "Disabilitato",
"comparingDesc": "Confronta due immagini",
"comparing": "Confronta",
"dontShowMeThese": "Non mostrarmi questi"
"dontShowMeThese": "Non mostrare più"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -701,7 +701,9 @@
"baseModelChanged": "Modello base modificato",
"sessionRef": "Sessione: {{sessionId}}",
"somethingWentWrong": "Qualcosa è andato storto",
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova.",
"importFailed": "Importazione non riuscita",
"importSuccessful": "Importazione riuscita"
},
"tooltip": {
"feature": {
@@ -927,7 +929,7 @@
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino ai valori predefiniti",
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
},
@@ -1526,7 +1528,7 @@
},
"upscaleModel": {
"paragraphs": [
"Il modello di ampliamento ridimensiona l'immagine alle dimensioni di uscita prima che vengano aggiunti i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
"Il modello di ampliamento (Upscale), scala l'immagine alle dimensioni di uscita prima di aggiungere i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
],
"heading": "Modello di ampliamento"
},
@@ -1735,12 +1737,58 @@
"missingUpscaleModel": "Modello per lampliamento mancante",
"missingTileControlNetModel": "Nessun modello ControlNet Tile valido installato",
"postProcessingModel": "Modello di post-elaborazione",
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine)."
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine).",
"exceedsMaxSize": "Le impostazioni di ampliamento superano il limite massimo delle dimensioni",
"exceedsMaxSizeDetails": "Il limite massimo di ampliamento è {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixel. Prova un'immagine più piccola o diminuisci la scala selezionata."
},
"upsell": {
"inviteTeammates": "Invita collaboratori",
"shareAccess": "Condividi l'accesso",
"professional": "Professionale",
"professionalUpsell": "Disponibile nell'edizione Professional di Invoke. Fai clic qui o visita invoke.com/pricing per ulteriori dettagli."
},
"stylePresets": {
"active": "Attivo",
"choosePromptTemplate": "Scegli un modello di prompt",
"clearTemplateSelection": "Cancella selezione modello",
"copyTemplate": "Copia modello",
"createPromptTemplate": "Crea modello di prompt",
"defaultTemplates": "Modelli predefiniti",
"deleteImage": "Elimina immagine",
"deleteTemplate": "Elimina modello",
"editTemplate": "Modifica modello",
"flatten": "Unisci il modello selezionato al prompt corrente",
"insertPlaceholder": "Inserisci segnaposto",
"myTemplates": "I miei modelli",
"name": "Nome",
"negativePrompt": "Prompt Negativo",
"noMatchingTemplates": "Nessun modello corrispondente",
"promptTemplatesDesc1": "I modelli di prompt aggiungono testo ai prompt che scrivi nelle caselle dei prompt.",
"promptTemplatesDesc3": "Se si omette il segnaposto, il modello verrà aggiunto alla fine del prompt.",
"positivePrompt": "Prompt Positivo",
"preview": "Anteprima",
"private": "Privato",
"searchByName": "Cerca per nome",
"shared": "Condiviso",
"sharedTemplates": "Modelli condivisi",
"templateDeleted": "Modello di prompt eliminato",
"toggleViewMode": "Attiva/disattiva visualizzazione",
"uploadImage": "Carica immagine",
"useForTemplate": "Usa per modello di prompt",
"viewList": "Visualizza l'elenco dei modelli",
"viewModeTooltip": "Ecco come apparirà il tuo prompt con il modello attualmente selezionato. Per modificare il tuo prompt, clicca in un punto qualsiasi della casella di testo.",
"deleteTemplate2": "Vuoi davvero eliminare questo modello? Questa operazione non può essere annullata.",
"unableToDeleteTemplate": "Impossibile eliminare il modello di prompt",
"updatePromptTemplate": "Aggiorna il modello di prompt",
"type": "Tipo",
"promptTemplatesDesc2": "Utilizza la stringa segnaposto <Pre>{{placeholder}}</Pre> per specificare dove inserire il tuo prompt nel modello.",
"importTemplates": "Importa modelli di prompt (CSV/JSON)",
"exportDownloaded": "Esportazione completata",
"exportFailed": "Impossibile generare e scaricare il file CSV",
"exportPromptTemplates": "Esporta i miei modelli di prompt (CSV)",
"positivePromptColumn": "'prompt' o 'positive_prompt'",
"noTemplates": "Nessun modello",
"acceptedColumnsKeys": "Colonne/chiavi accettate:",
"templateActions": "Azioni modello"
}
}

View File

@@ -91,7 +91,8 @@
"enabled": "Включено",
"disabled": "Отключено",
"comparingDesc": "Сравнение двух изображений",
"comparing": "Сравнение"
"comparing": "Сравнение",
"dontShowMeThese": "Не показывай мне это"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@@ -153,7 +154,11 @@
"showArchivedBoards": "Показать архивированные доски",
"searchImages": "Поиск по метаданным",
"displayBoardSearch": "Отобразить поиск досок",
"displaySearch": "Отобразить поиск"
"displaySearch": "Отобразить поиск",
"exitBoardSearch": "Выйти из поиска досок",
"go": "Перейти",
"exitSearch": "Выйти из поиска",
"jump": "Пыгнуть"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@@ -376,6 +381,10 @@
"toggleViewer": {
"title": "Переключить просмотр изображений",
"desc": "Переключение между средством просмотра изображений и рабочей областью для текущей вкладки."
},
"postProcess": {
"desc": "Обработайте текущее изображение с помощью выбранной модели постобработки",
"title": "Обработать изображение"
}
},
"modelManager": {
@@ -589,7 +598,10 @@
"infillColorValue": "Цвет заливки",
"globalSettings": "Глобальные настройки",
"globalNegativePromptPlaceholder": "Глобальный негативный запрос",
"globalPositivePromptPlaceholder": "Глобальный запрос"
"globalPositivePromptPlaceholder": "Глобальный запрос",
"postProcessing": "Постобработка (Shift + U)",
"processImage": "Обработка изображения",
"sendToUpscale": "Отправить на увеличение"
},
"settings": {
"models": "Модели",
@@ -623,7 +635,9 @@
"intermediatesCleared_many": "Очищено {{count}} промежуточных",
"clearIntermediatesDesc1": "Очистка промежуточных элементов приведет к сбросу состояния Canvas и ControlNet.",
"intermediatesClearedFailed": "Проблема очистки промежуточных",
"reloadingIn": "Перезагрузка через"
"reloadingIn": "Перезагрузка через",
"informationalPopoversDisabled": "Информационные всплывающие окна отключены",
"informationalPopoversDisabledDesc": "Информационные всплывающие окна были отключены. Включите их в Настройках."
},
"toast": {
"uploadFailed": "Загрузка не удалась",
@@ -694,7 +708,9 @@
"sessionRef": "Сессия: {{sessionId}}",
"outOfMemoryError": "Ошибка нехватки памяти",
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
"somethingWentWrong": "Что-то пошло не так"
"somethingWentWrong": "Что-то пошло не так",
"importFailed": "Импорт неудачен",
"importSuccessful": "Импорт успешен"
},
"tooltip": {
"feature": {
@@ -1017,7 +1033,8 @@
"composition": "Только композиция",
"hed": "HED",
"beginEndStepPercentShort": "Начало/конец %",
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)"
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)",
"depthAnythingSmallV2": "Small V2"
},
"boards": {
"autoAddBoard": "Авто добавление Доски",
@@ -1042,7 +1059,7 @@
"downloadBoard": "Скачать доску",
"deleteBoard": "Удалить доску",
"deleteBoardAndImages": "Удалить доску и изображения",
"deletedBoardsCannotbeRestored": "Удаленные доски не подлежат восстановлению",
"deletedBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в состояние без категории.",
"assetsWithCount_one": "{{count}} ассет",
"assetsWithCount_few": "{{count}} ассета",
"assetsWithCount_many": "{{count}} ассетов",
@@ -1057,7 +1074,11 @@
"boards": "Доски",
"addPrivateBoard": "Добавить личную доску",
"private": "Личные доски",
"shared": "Общие доски"
"shared": "Общие доски",
"hideBoards": "Скрыть доски",
"viewBoards": "Просмотреть доски",
"noBoards": "Нет досок {{boardType}}",
"deletedPrivateBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в приватное состояние без категории для создателя изображения."
},
"dynamicPrompts": {
"seedBehaviour": {
@@ -1417,6 +1438,30 @@
"paragraphs": [
"Метод, с помощью которого применяется текущий IP-адаптер."
]
},
"structure": {
"paragraphs": [
"Структура контролирует, насколько точно выходное изображение будет соответствовать макету оригинала. Низкая структура допускает значительные изменения, в то время как высокая структура строго сохраняет исходную композицию и макет."
],
"heading": "Структура"
},
"scale": {
"paragraphs": [
"Масштаб управляет размером выходного изображения и основывается на кратном разрешении входного изображения. Например, при увеличении в 2 раза изображения 1024x1024 на выходе получится 2048 x 2048."
],
"heading": "Масштаб"
},
"creativity": {
"paragraphs": [
"Креативность контролирует степень свободы, предоставляемой модели при добавлении деталей. При низкой креативности модель остается близкой к оригинальному изображению, в то время как высокая креативность позволяет вносить больше изменений. При использовании подсказки высокая креативность увеличивает влияние подсказки."
],
"heading": "Креативность"
},
"upscaleModel": {
"heading": "Модель увеличения",
"paragraphs": [
"Модель увеличения масштаба масштабирует изображение до выходного размера перед добавлением деталей. Можно использовать любую поддерживаемую модель масштабирования, но некоторые из них специализированы для различных видов изображений, например фотографий или линейных рисунков."
]
}
},
"metadata": {
@@ -1693,7 +1738,78 @@
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Очередь"
"queue": "Очередь",
"upscaling": "Увеличение",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
}
},
"upscaling": {
"exceedsMaxSize": "Параметры масштабирования превышают максимальный размер",
"exceedsMaxSizeDetails": "Максимальный предел масштабирования составляет {{maxUpscaleDimension}}x{{maxUpscaleDimension}} пикселей. Пожалуйста, попробуйте использовать меньшее изображение или уменьшите масштаб.",
"structure": "Структура",
"missingTileControlNetModel": "Не установлены подходящие модели ControlNet",
"missingUpscaleInitialImage": "Отсутствует увеличиваемое изображение",
"missingUpscaleModel": "Отсутствует увеличивающая модель",
"creativity": "Креативность",
"upscaleModel": "Модель увеличения",
"scale": "Масштаб",
"mainModelDesc": "Основная модель (архитектура SD1.5 или SDXL)",
"upscaleModelDesc": "Модель увеличения (img2img)",
"postProcessingModel": "Модель постобработки",
"tileControlNetModelDesc": "Модель ControlNet для выбранной архитектуры основной модели",
"missingModelsWarning": "Зайдите в <LinkComponent>Менеджер моделей</LinkComponent> чтоб установить необходимые модели:",
"postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img)."
},
"stylePresets": {
"noMatchingTemplates": "Нет подходящих шаблонов",
"promptTemplatesDesc1": "Шаблоны подсказок добавляют текст к подсказкам, которые вы пишете в окне подсказок.",
"sharedTemplates": "Общие шаблоны",
"templateDeleted": "Шаблон запроса удален",
"toggleViewMode": "Переключить режим просмотра",
"type": "Тип",
"unableToDeleteTemplate": "Не получилось удалить шаблон запроса",
"viewModeTooltip": "Вот как будет выглядеть ваш запрос с выбранным шаблоном. Чтобы его отредактировать, щелкните в любом месте текстового поля.",
"viewList": "Просмотреть список шаблонов",
"active": "Активно",
"choosePromptTemplate": "Выберите шаблон запроса",
"defaultTemplates": "Стандартные шаблоны",
"deleteImage": "Удалить изображение",
"deleteTemplate": "Удалить шаблон",
"deleteTemplate2": "Вы уверены, что хотите удалить этот шаблон? Это нельзя отменить.",
"editTemplate": "Редактировать шаблон",
"exportPromptTemplates": "Экспорт моих шаблонов запроса (CSV)",
"exportDownloaded": "Экспорт скачан",
"exportFailed": "Невозможно сгенерировать и загрузить CSV",
"flatten": "Объединить выбранный шаблон с текущим запросом",
"acceptedColumnsKeys": "Принимаемые столбцы/ключи:",
"positivePromptColumn": "'prompt' или 'positive_prompt'",
"insertPlaceholder": "Вставить заполнитель",
"name": "Имя",
"negativePrompt": "Негативный запрос",
"promptTemplatesDesc3": "Если вы не используете заполнитель, шаблон будет добавлен в конец запроса.",
"positivePrompt": "Позитивный запрос",
"preview": "Предпросмотр",
"private": "Приватный",
"templateActions": "Действия с шаблоном",
"updatePromptTemplate": "Обновить шаблон запроса",
"uploadImage": "Загрузить изображение",
"useForTemplate": "Использовать для шаблона запроса",
"clearTemplateSelection": "Очистить выбор шаблона",
"copyTemplate": "Копировать шаблон",
"createPromptTemplate": "Создать шаблон запроса",
"importTemplates": "Импортировать шаблоны запроса (CSV/JSON)",
"nameColumn": "'name'",
"negativePromptColumn": "'negative_prompt'",
"myTemplates": "Мои шаблоны",
"noTemplates": "Нет шаблонов",
"promptTemplatesDesc2": "Используйте строку-заполнитель <Pre>{{placeholder}}</Pre>, чтобы указать место, куда должен быть включен ваш запрос в шаблоне.",
"searchByName": "Поиск по имени",
"shared": "Общий"
},
"upsell": {
"inviteTeammates": "Пригласите членов команды",
"professional": "Профессионал",
"professionalUpsell": "Доступно в профессиональной версии Invoke. Нажмите здесь или посетите invoke.com/pricing для получения более подробной информации.",
"shareAccess": "Поделиться доступом"
}
}

View File

@@ -493,7 +493,8 @@
"defaultSettingsSaved": "默认设置已保存",
"huggingFacePlaceholder": "所有者或模型名称",
"huggingFaceRepoID": "HuggingFace仓库ID",
"loraTriggerPhrases": "LoRA 触发词"
"loraTriggerPhrases": "LoRA 触发词",
"ipAdapters": "IP适配器"
},
"parameters": {
"images": "图像",
@@ -1702,7 +1703,9 @@
"upscaleModelDesc": "图像放大(图像到图像转换)模型",
"postProcessingMissingModelWarning": "请访问 <LinkComponent>模型管理器</LinkComponent>来安装一个后处理(图像到图像转换)模型.",
"missingModelsWarning": "请访问<LinkComponent>模型管理器</LinkComponent> 安装所需的模型:",
"mainModelDesc": "主模型SD1.5或SDXL架构"
"mainModelDesc": "主模型SD1.5或SDXL架构",
"exceedsMaxSize": "放大设置超出了最大尺寸限制",
"exceedsMaxSizeDetails": "最大放大限制是 {{maxUpscaleDimension}}x{{maxUpscaleDimension}} 像素.请尝试一个较小的图像或减少您的缩放选择."
},
"upsell": {
"inviteTeammates": "邀请团队成员",

View File

@@ -1,26 +1,40 @@
/* eslint-disable no-console */
import fs from 'node:fs';
import openapiTS from 'openapi-typescript';
import openapiTS, { astToString } from 'openapi-typescript';
import ts from 'typescript';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.ts';
async function generateTypes(schema) {
process.stdout.write(`Generating types ${OUTPUT_FILE}...`);
// Use https://ts-ast-viewer.com to figure out how to create these AST nodes - define a type and use the bottom-left pane's output
// `Blob` type
const BLOB = ts.factory.createTypeReferenceNode(ts.factory.createIdentifier('Blob'));
// `null` type
const NULL = ts.factory.createLiteralTypeNode(ts.factory.createNull());
// `Record<string, unknown>` type
const RECORD_STRING_UNKNOWN = ts.factory.createTypeReferenceNode(ts.factory.createIdentifier('Record'), [
ts.factory.createKeywordTypeNode(ts.SyntaxKind.StringKeyword),
ts.factory.createKeywordTypeNode(ts.SyntaxKind.UnknownKeyword),
]);
const types = await openapiTS(schema, {
exportType: true,
transform: (schemaObject) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob';
return schemaObject.nullable ? ts.factory.createUnionTypeNode([BLOB, NULL]) : BLOB;
}
if (schemaObject.title === 'MetadataField') {
// This is `Record<string, never>` by default, but it actually accepts a dict of any valid JSON value.
return 'Record<string, unknown>';
return RECORD_STRING_UNKNOWN;
}
},
defaultNonNullable: false,
});
fs.writeFileSync(OUTPUT_FILE, types);
fs.writeFileSync(OUTPUT_FILE, astToString(types));
process.stdout.write(`\nOK!\r\n`);
}
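
The script above migrates from openapi-typescript v6, where transform() returned type strings, to v7, which expects TypeScript AST nodes and serializes the final output with astToString(). A minimal sketch of building and inspecting such a node in isolation (the printer boilerplate below is illustrative only, not part of the repo script):

import ts from 'typescript';

// Build the `Blob | null` union the transform above returns for nullable binary fields.
const BLOB = ts.factory.createTypeReferenceNode(ts.factory.createIdentifier('Blob'));
const NULL = ts.factory.createLiteralTypeNode(ts.factory.createNull());
const nullableBlob = ts.factory.createUnionTypeNode([BLOB, NULL]);

// Print the node to confirm the emitted type text.
const printer = ts.createPrinter();
const sourceFile = ts.createSourceFile('scratch.ts', '', ts.ScriptTarget.Latest);
console.log(printer.printNode(ts.EmitHint.Unspecified, nullableBlob, sourceFile)); // Blob | null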

View File

@@ -13,11 +13,13 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
import { AnimatePresence } from 'framer-motion';
import i18n from 'i18n';
import { size } from 'lodash-es';
@@ -36,10 +38,11 @@ interface Props {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
destination?: InvokeTabName | undefined;
}
const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) => {
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger('system');
const dispatch = useAppDispatch();
@@ -70,6 +73,14 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) =>
}
}, [dispatch, config, logger]);
const { getAndLoadWorkflow } = useGetAndLoadLibraryWorkflow();
useEffect(() => {
if (selectedWorkflowId) {
getAndLoadWorkflow(selectedWorkflowId);
}
}, [selectedWorkflowId, getAndLoadWorkflow]);
useEffect(() => {
if (destination) {
dispatch(setActiveTab(destination));
@@ -104,6 +115,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) =>
<DeleteImageModal />
<ChangeBoardModal />
<DynamicPromptsModal />
<StylePresetModal />
<PreselectedImage selectedImage={selectedImage} />
</ErrorBoundary>
);

View File

@@ -44,6 +44,7 @@ interface Props extends PropsWithChildren {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
destination?: InvokeTabName;
customStarUi?: CustomStarUi;
socketOptions?: Partial<ManagerOptions & SocketOptions>;
@@ -64,6 +65,7 @@ const InvokeAIUI = ({
projectUrl,
queueId,
selectedImage,
selectedWorkflowId,
destination,
customStarUi,
socketOptions,
@@ -221,7 +223,12 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<AppDndContext>
<App config={config} selectedImage={selectedImage} destination={destination} />
<App
config={config}
selectedImage={selectedImage}
selectedWorkflowId={selectedWorkflowId}
destination={destination}
/>
</AppDndContext>
</ThemeLocaleProvider>
</React.Suspense>

View File

@@ -11,6 +11,9 @@ import {
promptsChanged,
} from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { socketConnected } from 'services/events/actions';
@@ -19,7 +22,11 @@ const matcher = isAnyOf(
combinatorialToggled,
maxPromptsChanged,
maxPromptsReset,
socketConnected
socketConnected,
activeStylePresetIdChanged,
stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled,
stylePresetsApi.endpoints.updateStylePreset.matchFulfilled,
stylePresetsApi.endpoints.listStylePresets.matchFulfilled
);
export const addDynamicPromptsListener = (startAppListening: AppStartListening) => {
@@ -28,7 +35,7 @@ export const addDynamicPromptsListener = (startAppListening: AppStartListening)
effect: async (action, { dispatch, getState, cancelActiveListeners, delay }) => {
cancelActiveListeners();
const state = getState();
const { positivePrompt } = state.controlLayers.present;
const { positivePrompt } = getPresetModifiedPrompts(state);
const { maxPrompts } = state.dynamicPrompts;
if (state.config.disabledFeatures.includes('dynamicPrompting')) {
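
getPresetModifiedPrompts is imported here and replaces the direct reads of state.controlLayers.present (and getSDXLStylePrompts) throughout the graph builders later in this diff, but its implementation is not included. A rough sketch of the placeholder behaviour described by the promptTemplatesDesc strings above, assuming a '{prompt}' placeholder token and hypothetical names:

// Hypothetical sketch only - the real helper lives in features/nodes/util/graph/graphBuilderUtils
// and also merges the SDXL style prompts; the names and placeholder token are assumptions.
const PRESET_PLACEHOLDER = '{prompt}';

export const applyPromptTemplate = (template: string | undefined, prompt: string): string => {
  if (!template) {
    return prompt;
  }
  return template.includes(PRESET_PLACEHOLDER)
    ? template.replace(PRESET_PLACEHOLDER, prompt) // insert the user's prompt where the placeholder sits
    : `${prompt} ${template}`.trim(); // no placeholder: append the template to the end of the prompt
};

// e.g. applyPromptTemplate('cinematic photo of {prompt}, 35mm film', 'a red fox')
//   -> 'cinematic photo of a red fox, 35mm film'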

View File

@@ -28,6 +28,7 @@ import { generationPersistConfig, generationSlice } from 'features/parameters/st
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { stylePresetPersistConfig, stylePresetSlice } from 'features/stylePresets/store/stylePresetSlice';
import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
@@ -69,6 +70,7 @@ const allReducers = {
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[api.reducerPath]: api.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
[stylePresetSlice.name]: stylePresetSlice.reducer,
};
const rootReducer = combineReducers(allReducers);
@@ -114,6 +116,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {
@@ -164,8 +167,8 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: false,
immutableCheck: false,
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
})
.concat(api.middleware)
.concat(dynamicMiddlewares)

View File

@@ -71,6 +71,7 @@ export type AppConfig = {
*/
maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
disabledTabs: InvokeTabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];

View File

@@ -47,6 +47,7 @@ export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
userSelect: 'none',
opacity: 0.7,
color: 'base.500',
fontSize: 'md',
...sx,
}),
[sx]
@@ -55,11 +56,7 @@ export const IAINoContentFallback = memo((props: IAINoImageFallbackProps) => {
return (
<Flex sx={styles} {...rest}>
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && (
<Text textAlign="center" fontSize="md">
{props.label}
</Text>
)}
{props.label && <Text textAlign="center">{props.label}</Text>}
</Flex>
);
});

View File

@@ -1,4 +1,4 @@
import { useImageUrlToBlob } from 'common/hooks/useImageUrlToBlob';
import { convertImageUrlToBlob } from 'common/util/convertImageUrlToBlob';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
@@ -6,7 +6,6 @@ import { useTranslation } from 'react-i18next';
export const useCopyImageToClipboard = () => {
const { t } = useTranslation();
const imageUrlToBlob = useImageUrlToBlob();
const isClipboardAPIAvailable = useMemo(() => {
return Boolean(navigator.clipboard) && Boolean(window.ClipboardItem);
@@ -23,7 +22,7 @@ export const useCopyImageToClipboard = () => {
});
}
try {
const blob = await imageUrlToBlob(image_url);
const blob = await convertImageUrlToBlob(image_url);
if (!blob) {
throw new Error('Unable to create Blob');
@@ -45,7 +44,7 @@ export const useCopyImageToClipboard = () => {
});
}
},
[imageUrlToBlob, isClipboardAPIAvailable, t]
[isClipboardAPIAvailable, t]
);
return { isClipboardAPIAvailable, copyImageToClipboard };

View File

@@ -1,40 +0,0 @@
import { $authToken } from 'app/store/nanostores/authToken';
import { useCallback } from 'react';
/**
* Converts an image URL to a Blob by creating an <img /> element, drawing it to canvas
* and then converting the canvas to a Blob.
*
* @returns A function that takes a URL and returns a Promise that resolves with a Blob
*/
export const useImageUrlToBlob = () => {
const imageUrlToBlob = useCallback(
async (url: string) =>
new Promise<Blob | null>((resolve) => {
const img = new Image();
img.onload = () => {
const canvas = document.createElement('canvas');
canvas.width = img.width;
canvas.height = img.height;
const context = canvas.getContext('2d');
if (!context) {
return;
}
context.drawImage(img, 0, 0);
resolve(
new Promise<Blob | null>((resolve) => {
canvas.toBlob(function (blob) {
resolve(blob);
}, 'image/png');
})
);
};
img.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
img.src = url;
}),
[]
);
return imageUrlToBlob;
};

View File

@@ -0,0 +1,39 @@
import { $authToken } from 'app/store/nanostores/authToken';
/**
* Converts an image URL to a Blob by creating an <img /> element, drawing it to canvas
* and then converting the canvas to a Blob.
*
* @returns A Promise that resolves with a Blob, or rejects if the image cannot be loaded or converted
*/
export const convertImageUrlToBlob = async (url: string) =>
new Promise<Blob | null>((resolve, reject) => {
const img = new Image();
img.onload = () => {
const canvas = document.createElement('canvas');
canvas.width = img.width;
canvas.height = img.height;
const context = canvas.getContext('2d');
if (!context) {
reject(new Error('Failed to get canvas context'));
return;
}
context.drawImage(img, 0, 0);
canvas.toBlob((blob) => {
if (blob) {
resolve(blob);
} else {
reject(new Error('Failed to convert image to blob'));
}
}, 'image/png');
};
img.onerror = () => {
reject(new Error('Image failed to load. The URL may be invalid or the object may not exist.'));
};
img.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
img.src = url;
});
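
This standalone, promise-based utility replaces the useImageUrlToBlob hook deleted above: it rejects on load or conversion failures instead of resolving null or never settling, which is what lets callers such as useCopyImageToClipboard surface an error. A minimal consumption sketch (assumed call site, not part of the diff):

import { convertImageUrlToBlob } from 'common/util/convertImageUrlToBlob';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';

const copyImage = async (imageUrl: string) => {
  try {
    // Rejects if the URL fails to load, the canvas context is unavailable, or toBlob() returns null.
    const blob = await convertImageUrlToBlob(imageUrl);
    copyBlobToClipboard(blob);
  } catch (err) {
    console.error(err); // the UI wraps this in an error toast
  }
};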

View File

@@ -42,6 +42,7 @@ const DepthAnythingProcessor = (props: Props) => {
const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
() => [
{ label: t('controlnet.depthAnythingSmallV2'), value: 'small_v2' },
{ label: t('controlnet.small'), value: 'small' },
{ label: t('controlnet.base'), value: 'base' },
{ label: t('controlnet.large'), value: 'large' },

View File

@@ -94,7 +94,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
buildDefaults: (baseModel?: BaseModelType) => ({
id: 'depth_anything_image_processor',
type: 'depth_anything_image_processor',
model_size: 'small',
model_size: 'small_v2',
resolution: baseModel === 'sdxl' ? 1024 : 512,
}),
},

View File

@@ -84,7 +84,7 @@ export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
'type' | 'model_size' | 'resolution' | 'offload'
>;
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small', 'small_v2']);
export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize =>
zDepthAnythingModelSize.safeParse(v).success;
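
Widening the enum means a persisted or recalled model_size of 'small_v2' now passes the type guard instead of being rejected. A quick usage sketch of the guard as defined above:

import { z } from 'zod';

const zDepthAnythingModelSize = z.enum(['large', 'base', 'small', 'small_v2']);
type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;

const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize =>
  zDepthAnythingModelSize.safeParse(v).success;

console.log(isDepthAnythingModelSize('small_v2')); // true after this change, false before
console.log(isDepthAnythingModelSize('tiny'));     // false - unknown values are still rejected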

View File

@@ -24,6 +24,7 @@ export const DepthAnythingProcessor = memo(({ onChange, config }: Props) => {
const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
() => [
{ label: t('controlnet.depthAnythingSmallV2'), value: 'small_v2' },
{ label: t('controlnet.small'), value: 'small' },
{ label: t('controlnet.base'), value: 'base' },
{ label: t('controlnet.large'), value: 'large' },

View File

@@ -36,7 +36,7 @@ const zContentShuffleProcessorConfig = z.object({
});
export type ContentShuffleProcessorConfig = z.infer<typeof zContentShuffleProcessorConfig>;
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
const zDepthAnythingModelSize = z.enum(['large', 'base', 'small', 'small_v2']);
export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize =>
zDepthAnythingModelSize.safeParse(v).success;
@@ -298,7 +298,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
buildDefaults: () => ({
id: 'depth_anything_image_processor',
type: 'depth_anything_image_processor',
model_size: 'small',
model_size: 'small_v2',
}),
buildNode: (image, config) => ({
...config,

View File

@@ -66,7 +66,7 @@ export const Gallery = () => {
<Flex flexDirection="column" alignItems="center" justifyContent="space-between" h="full" w="full" pt={1}>
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full">
<TabList gap={2} fontSize="sm" borderColor="base.800" alignItems="center" w="full">
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2">
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2" wordBreak="break-all">
{boardName}
</Text>
<Spacer />

View File

@@ -30,6 +30,7 @@ import {
PiFlowArrowBold,
PiFoldersBold,
PiImagesBold,
PiPaintBrushBold,
PiPlantBold,
PiQuotesBold,
PiShareFatBold,
@@ -55,8 +56,17 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { downloadImage } = useDownloadImage();
const templates = useStore($templates);
const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata } =
useImageActions(imageDTO?.image_name);
const {
recallAll,
remix,
recallSeed,
recallPrompts,
hasMetadata,
hasSeed,
hasPrompts,
isLoadingMetadata,
createAsPreset,
} = useImageActions(imageDTO?.image_name);
const { getAndLoadEmbeddedWorkflow, getAndLoadEmbeddedWorkflowResult } = useGetAndLoadEmbeddedWorkflow({});
@@ -182,6 +192,13 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={isLoadingMetadata ? <SpinnerIcon /> : <PiPaintBrushBold />}
onClickCapture={createAsPreset}
isDisabled={isLoadingMetadata || !hasPrompts}
>
{t('stylePresets.useForTemplate')}
</MenuItem>
<MenuDivider />
<MenuItem icon={<PiShareFatBold />} onClickCapture={handleSendToImageToImage} id="send-to-img2img">
{t('parameters.sendToImg2Img')}

View File

@@ -60,7 +60,6 @@ const ImageGalleryContent = () => {
<GalleryHeader />
<Flex alignItems="center" justifyContent="space-between" w="full">
<Button
w={112}
size="sm"
variant="ghost"
onClick={handleToggleBoardPanel}

View File

@@ -1,15 +1,25 @@
import { useAppSelector } from 'app/store/storeHooks';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { handlers, parseAndRecallAllMetadata, parseAndRecallPrompts } from 'features/metadata/util/handlers';
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
export const useImageActions = (image_name?: string) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const activeTabName = useAppSelector(activeTabNameSelector);
const activeStylePresetId = useAppSelector((s) => s.stylePreset.activeStylePresetId);
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(image_name);
const [hasMetadata, setHasMetadata] = useState(false);
const [hasSeed, setHasSeed] = useState(false);
const [hasPrompts, setHasPrompts] = useState(false);
const { data: imageDTO } = useGetImageDTOQuery(image_name ?? skipToken);
useEffect(() => {
const parseMetadata = async () => {
@@ -42,14 +52,26 @@ export const useImageActions = (image_name?: string) => {
parseMetadata();
}, [metadata]);
const clearStylePreset = useCallback(() => {
if (activeStylePresetId) {
dispatch(activeStylePresetIdChanged(null));
toast({
status: 'info',
title: t('stylePresets.promptTemplateCleared'),
});
}
}, [dispatch, activeStylePresetId, t]);
const recallAll = useCallback(() => {
parseAndRecallAllMetadata(metadata, activeTabName === 'generation');
}, [activeTabName, metadata]);
clearStylePreset();
}, [activeTabName, metadata, clearStylePreset]);
const remix = useCallback(() => {
// Recalls all metadata parameters except seed
parseAndRecallAllMetadata(metadata, activeTabName === 'generation', ['seed']);
}, [activeTabName, metadata]);
clearStylePreset();
}, [activeTabName, metadata, clearStylePreset]);
const recallSeed = useCallback(() => {
handlers.seed.parse(metadata).then((seed) => {
@@ -59,7 +81,37 @@ export const useImageActions = (image_name?: string) => {
const recallPrompts = useCallback(() => {
parseAndRecallPrompts(metadata);
}, [metadata]);
clearStylePreset();
}, [metadata, clearStylePreset]);
return { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata };
const createAsPreset = useCallback(async () => {
if (image_name && metadata && imageDTO) {
const positivePrompt = await handlers.positivePrompt.parse(metadata);
const negativePrompt = await handlers.negativePrompt.parse(metadata);
$stylePresetModalState.set({
prefilledFormData: {
name: '',
positivePrompt,
negativePrompt,
imageUrl: imageDTO.image_url,
type: 'user',
},
updatingStylePresetId: null,
isModalOpen: true,
});
}
}, [image_name, metadata, imageDTO]);
return {
recallAll,
remix,
recallSeed,
recallPrompts,
hasMetadata,
hasSeed,
hasPrompts,
isLoadingMetadata,
createAsPreset,
};
};
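
createAsPreset pre-fills the style preset modal through the $stylePresetModalState nanostores atom. The atom's definition is not part of this diff; a sketch of the shape implied by the set() call above (field types are assumptions):

import { atom } from 'nanostores';

type StylePresetModalState = {
  isModalOpen: boolean;
  updatingStylePresetId: string | null;
  prefilledFormData: {
    name: string;
    positivePrompt: string;
    negativePrompt: string;
    imageUrl: string | null;
    type: string; // 'user' in the call above; other values are not shown in this diff
  } | null;
};

// Assumed initial state: modal closed, nothing prefilled.
export const $stylePresetModalState = atom<StylePresetModalState>({
  isModalOpen: false,
  updatingStylePresetId: null,
  prefilledFormData: null,
});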

View File

@@ -22,11 +22,10 @@ import {
} from './constants';
import { addLoRAs } from './generation/addLoRAs';
import { addSDXLLoRas } from './generation/addSDXLLoRAs';
import { getBoardField, getSDXLStylePrompts } from './graphBuilderUtils';
import { getBoardField, getPresetModifiedPrompts } from './graphBuilderUtils';
export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise<GraphType> => {
const { model, cfgScale: cfg_scale, scheduler, steps, vaePrecision, seed, vae } = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { upscaleModel, upscaleInitialImage, structure, creativity, tileControlnetModel, scale } = state.upscale;
assert(model, 'No model found in state');
@@ -99,7 +98,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
let modelNode;
if (model.base === 'sdxl') {
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
getPresetModifiedPrompts(state);
posCondNode = g.addNode({
type: 'sdxl_compel_prompt',
@@ -132,6 +132,8 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
negative_style_prompt: negativeStylePrompt,
});
} else {
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
posCondNode = g.addNode({
type: 'compel',
id: POSITIVE_CONDITIONING,

View File

@@ -16,7 +16,7 @@ import {
SDXL_REFINER_POSITIVE_CONDITIONING,
SDXL_REFINER_SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph } from 'services/api/types';
import { isRefinerMainModelModelConfig } from 'services/api/types';
@@ -59,7 +59,7 @@ export const addSDXLRefinerToGraph = async (
const modelLoaderId = modelLoaderNodeId ? modelLoaderNodeId : SDXL_MODEL_LOADER;
// Construct Style Prompt
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
const { positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
// Unplug SDXL Latents Generation To Latents To Image
graph.edges = graph.edges.filter((e) => !(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field)));

View File

@@ -16,7 +16,11 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import {
getBoardField,
getIsIntermediate,
getPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
@@ -51,7 +55,6 @@ export const buildCanvasImageToImageGraph = async (
seamlessXAxis,
seamlessYAxis,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
@@ -71,6 +74,8 @@ export const buildCanvasImageToImageGraph = async (
const use_cpu = shouldUseCpuNoise;
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node

View File

@@ -19,7 +19,11 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import {
getBoardField,
getIsIntermediate,
getPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
@@ -58,7 +62,6 @@ export const buildCanvasInpaintGraph = async (
canvasCoherenceEdgeSize,
maskBlur,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
if (!model) {
log.error('No model found in state');
@@ -79,6 +82,8 @@ export const buildCanvasInpaintGraph = async (
const use_cpu = shouldUseCpuNoise;
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
const graph: NonNullableGraph = {
id: CANVAS_INPAINT_GRAPH,
nodes: {

View File

@@ -23,7 +23,11 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import {
getBoardField,
getIsIntermediate,
getPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
@@ -70,7 +74,6 @@ export const buildCanvasOutpaintGraph = async (
canvasCoherenceEdgeSize,
maskBlur,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
if (!model) {
log.error('No model found in state');
@@ -91,6 +94,8 @@ export const buildCanvasOutpaintGraph = async (
const use_cpu = shouldUseCpuNoise;
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
const graph: NonNullableGraph = {
id: CANVAS_OUTPAINT_GRAPH,
nodes: {

View File

@@ -16,7 +16,11 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import {
getBoardField,
getIsIntermediate,
getPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
@@ -51,7 +55,6 @@ export const buildCanvasSDXLImageToImageGraph = async (
seamlessYAxis,
img2imgStrength: strength,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { refinerModel, refinerStart } = state.sdxl;
@@ -75,7 +78,7 @@ export const buildCanvasSDXLImageToImageGraph = async (
const use_cpu = shouldUseCpuNoise;
// Construct Style Prompt
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the

View File

@@ -19,7 +19,11 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import {
getBoardField,
getIsIntermediate,
getPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
@@ -58,7 +62,6 @@ export const buildCanvasSDXLInpaintGraph = async (
canvasCoherenceEdgeSize,
maskBlur,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { refinerModel, refinerStart } = state.sdxl;
@@ -83,7 +86,7 @@ export const buildCanvasSDXLInpaintGraph = async (
const use_cpu = shouldUseCpuNoise;
// Construct Style Prompt
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
const graph: NonNullableGraph = {
id: SDXL_CANVAS_INPAINT_GRAPH,

View File

@@ -23,7 +23,11 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import {
getBoardField,
getIsIntermediate,
getPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageDTO, Invocation, NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
@@ -70,7 +74,6 @@ export const buildCanvasSDXLOutpaintGraph = async (
canvasCoherenceEdgeSize,
maskBlur,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { refinerModel, refinerStart } = state.sdxl;
@@ -94,7 +97,7 @@ export const buildCanvasSDXLOutpaintGraph = async (
const use_cpu = shouldUseCpuNoise;
// Construct Style Prompt
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
const graph: NonNullableGraph = {
id: SDXL_CANVAS_OUTPAINT_GRAPH,

View File

@@ -14,7 +14,11 @@ import {
SDXL_REFINER_SEAMLESS,
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import {
getBoardField,
getIsIntermediate,
getPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
@@ -44,7 +48,6 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
seamlessXAxis,
seamlessYAxis,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
@@ -67,7 +70,7 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
let modelLoaderNodeId = SDXL_MODEL_LOADER;
// Construct Style Prompt
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the

View File

@@ -14,7 +14,11 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from 'features/nodes/util/graph/constants';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import {
getBoardField,
getIsIntermediate,
getPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
@@ -44,7 +48,6 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
seamlessXAxis,
seamlessYAxis,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
@@ -64,6 +67,8 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
let modelLoaderNodeId = MAIN_MODEL_LOADER;
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node

View File

@@ -22,7 +22,7 @@ import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import type { GraphType } from 'features/nodes/util/graph/generation/Graph';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import { getBoardField, getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@@ -40,11 +40,12 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
seed,
vae,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
assert(model, 'No model found in state');
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
const g = new Graph(CONTROL_LAYERS_GRAPH);
const modelLoader = g.addNode({
type: 'main_model_loader',

View File

@@ -19,7 +19,7 @@ import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefi
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getSDXLStylePrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { getBoardField, getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation, NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
@@ -36,14 +36,13 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
vaePrecision,
vae,
} = state.generation;
const { positivePrompt, negativePrompt } = state.controlLayers.present;
const { width, height } = state.controlLayers.present.size;
const { refinerModel, refinerStart } = state.sdxl;
assert(model, 'No model found in state');
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } = getPresetModifiedPrompts(state);
const g = new Graph(SDXL_CONTROL_LAYERS_GRAPH);
const modelLoader = g.addNode({

View File

@@ -1,6 +1,8 @@
import type { RootState } from 'app/store/store';
import type { BoardField } from 'features/nodes/types/common';
import { buildPresetModifiedPrompt } from 'features/stylePresets/hooks/usePresetModifiedPrompts';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
/**
* Gets the board field, based on the autoAddBoardId setting.
@@ -14,13 +16,43 @@ export const getBoardField = (state: RootState): BoardField | undefined => {
};
/**
* Gets the SDXL style prompts, based on the concat setting.
* Gets the prompts, modified for the active style preset.
*/
export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => {
export const getPresetModifiedPrompts = (
state: RootState
): { positivePrompt: string; negativePrompt: string; positiveStylePrompt?: string; negativeStylePrompt?: string } => {
const { positivePrompt, negativePrompt, positivePrompt2, negativePrompt2, shouldConcatPrompts } =
state.controlLayers.present;
const { activeStylePresetId } = state.stylePreset;
if (activeStylePresetId) {
const { data } = stylePresetsApi.endpoints.listStylePresets.select()(state);
const activeStylePreset = data?.find((item) => item.id === activeStylePresetId);
if (activeStylePreset) {
const presetModifiedPositivePrompt = buildPresetModifiedPrompt(
activeStylePreset.preset_data.positive_prompt,
positivePrompt
);
const presetModifiedNegativePrompt = buildPresetModifiedPrompt(
activeStylePreset.preset_data.negative_prompt,
negativePrompt
);
return {
positivePrompt: presetModifiedPositivePrompt,
negativePrompt: presetModifiedNegativePrompt,
positiveStylePrompt: shouldConcatPrompts ? presetModifiedPositivePrompt : positivePrompt2,
negativeStylePrompt: shouldConcatPrompts ? presetModifiedNegativePrompt : negativePrompt2,
};
}
}
return {
positivePrompt,
negativePrompt,
positiveStylePrompt: shouldConcatPrompts ? positivePrompt : positivePrompt2,
negativeStylePrompt: shouldConcatPrompts ? negativePrompt : negativePrompt2,
};
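For orientation, a minimal sketch of how a graph builder consumes the widened helper; the builder name and returned shape below are invented for illustration and are not part of this changeset, but the destructuring mirrors the diffs above.

import type { RootState } from 'app/store/store';
import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';

// Illustrative only: a single call now yields both the base prompts (already modified by the
// active style preset, if any) and the optional SDXL style prompts.
const buildToyPromptFields = (state: RootState) => {
  const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
    getPresetModifiedPrompts(state);

  return {
    positive: { prompt: positivePrompt, style: positiveStylePrompt ?? positivePrompt },
    negative: { prompt: negativePrompt, style: negativeStylePrompt ?? negativePrompt },
  };
};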

View File

@@ -1,16 +1,32 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { negativePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';
export const ParamNegativePrompt = memo(() => {
const dispatch = useAppDispatch();
const prompt = useAppSelector((s) => s.controlLayers.present.negativePrompt);
const viewMode = useAppSelector((s) => s.stylePreset.viewMode);
const activeStylePresetId = useAppSelector((s) => s.stylePreset.activeStylePresetId);
const { activeStylePreset } = useListStylePresetsQuery(undefined, {
selectFromResult: ({ data }) => {
let activeStylePreset = null;
if (data) {
activeStylePreset = data.find((sp) => sp.id === activeStylePresetId);
}
return { activeStylePreset };
},
});
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();
const _onChange = useCallback(
@@ -27,22 +43,34 @@ export const ParamNegativePrompt = memo(() => {
return (
<PromptPopover isOpen={isOpen} onClose={onClose} onSelect={onSelect} width={textareaRef.current?.clientWidth}>
<Box pos="relative">
<Box pos="relative" w="full">
<Textarea
id="negativePrompt"
name="negativePrompt"
ref={textareaRef}
value={prompt}
placeholder={t('parameters.globalNegativePromptPlaceholder')}
onChange={onChange}
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
minH={28}
borderTopWidth={24} // This prevents the prompt from being hidden behind the header
paddingInlineEnd={10}
paddingInlineStart={3}
paddingTop={0}
paddingBottom={3}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
<PromptLabel label={t('parameters.negativePromptPlaceholder')} />
{viewMode && (
<ViewModePrompt
prompt={prompt}
presetPrompt={activeStylePreset?.preset_data.negative_prompt || ''}
label={`${t('parameters.negativePromptPlaceholder')} (${t('stylePresets.preview')})`}
/>
)}
</Box>
</PromptPopover>
);

View File

@@ -2,7 +2,9 @@ import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { positivePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/components/ShowDynamicPromptsPreviewButton';
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
import { usePrompt } from 'features/prompt/usePrompt';
@@ -11,11 +13,24 @@ import { memo, useCallback, useRef } from 'react';
import type { HotkeyCallback } from 'react-hotkeys-hook';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';
export const ParamPositivePrompt = memo(() => {
const dispatch = useAppDispatch();
const prompt = useAppSelector((s) => s.controlLayers.present.positivePrompt);
const baseModel = useAppSelector((s) => s.generation.model)?.base;
const viewMode = useAppSelector((s) => s.stylePreset.viewMode);
const activeStylePresetId = useAppSelector((s) => s.stylePreset.activeStylePresetId);
const { activeStylePreset } = useListStylePresetsQuery(undefined, {
selectFromResult: ({ data }) => {
let activeStylePreset = null;
if (data) {
activeStylePreset = data.find((sp) => sp.id === activeStylePresetId);
}
return { activeStylePreset };
},
});
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();
@@ -49,18 +64,29 @@ export const ParamPositivePrompt = memo(() => {
name="prompt"
ref={textareaRef}
value={prompt}
placeholder={t('parameters.globalPositivePromptPlaceholder')}
onChange={onChange}
minH={28}
minH={40}
onKeyDown={onKeyDown}
variant="darkFilled"
paddingRight={30}
borderTopWidth={24} // This prevents the prompt from being hidden behind the header
paddingInlineEnd={10}
paddingInlineStart={3}
paddingTop={0}
paddingBottom={3}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
{baseModel === 'sdxl' && <SDXLConcatButton />}
<ShowDynamicPromptsPreviewButton />
</PromptOverlayButtonWrapper>
<PromptLabel label={t('parameters.positivePromptPlaceholder')} />
{viewMode && (
<ViewModePrompt
prompt={prompt}
presetPrompt={activeStylePreset?.preset_data.positive_prompt || ''}
label={`${t('parameters.positivePromptPlaceholder')} (${t('stylePresets.preview')})`}
/>
)}
</Box>
</PromptPopover>
);
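ParamPositivePrompt and ParamNegativePrompt now perform the same selectFromResult lookup of the active preset. A hypothetical shared hook (useActiveStylePreset is an invented name, not part of this diff) could centralize it roughly like this:

import { useAppSelector } from 'app/store/storeHooks';
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';

// Sketch of a possible shared hook; mirrors the selectFromResult pattern used in both components.
export const useActiveStylePreset = () => {
  const activeStylePresetId = useAppSelector((s) => s.stylePreset.activeStylePresetId);
  return useListStylePresetsQuery(undefined, {
    selectFromResult: ({ data }) => ({
      activeStylePreset: data?.find((sp) => sp.id === activeStylePresetId) ?? null,
    }),
  });
};

Usage in either component would then be const { activeStylePreset } = useActiveStylePreset();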

View File

@@ -0,0 +1,9 @@
import { Text } from '@invoke-ai/ui-library';
export const PromptLabel = ({ label }: { label: string }) => {
return (
<Text variant="subtext" fontWeight="semibold" pos="absolute" top={1} left={2}>
{label}
</Text>
);
};

View File

@@ -1,13 +1,29 @@
import { Flex } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectControlLayersSlice } from 'features/controlLayers/store/controlLayersSlice';
import { ParamNegativePrompt } from 'features/parameters/components/Core/ParamNegativePrompt';
import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPositivePrompt';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { ParamSDXLNegativeStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt';
import { ParamSDXLPositiveStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLPositiveStylePrompt';
import { memo } from 'react';
const concatPromptsSelector = createSelector(
[selectGenerationSlice, selectControlLayersSlice],
(generation, controlLayers) => {
return generation.model?.base !== 'sdxl' || controlLayers.present.shouldConcatPrompts;
}
);
export const Prompts = memo(() => {
const shouldConcatPrompts = useAppSelector(concatPromptsSelector);
return (
<Flex flexDir="column" gap={2}>
<ParamPositivePrompt />
{!shouldConcatPrompts && <ParamSDXLPositiveStylePrompt />}
<ParamNegativePrompt />
{!shouldConcatPrompts && <ParamSDXLNegativeStylePrompt />}
</Flex>
);
});
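As a reading aid for the selector above (a sketch of its three cases, assuming exactly the logic shown in the diff; not additional code in the changeset):

// concatPromptsSelector returns true whenever the separate SDXL style prompt boxes should be hidden:
//   base !== 'sdxl'                          -> true  (non-SDXL models have no style prompts)
//   base === 'sdxl' && shouldConcatPrompts   -> true  (style prompts are derived from the main prompts)
//   base === 'sdxl' && !shouldConcatPrompts  -> false (render ParamSDXLPositive/NegativeStylePrompt)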

View File

@@ -0,0 +1,81 @@
import { Box, Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { viewModeChanged } from 'features/stylePresets/store/stylePresetSlice';
import { getViewModeChunks } from 'features/stylePresets/util/getViewModeChunks';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiEyeBold } from 'react-icons/pi';
import { PromptLabel } from './PromptLabel';
export const ViewModePrompt = ({
presetPrompt,
prompt,
label,
}: {
presetPrompt: string;
prompt: string;
label: string;
}) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const presetChunks = useMemo(() => {
return getViewModeChunks(prompt, presetPrompt);
}, [presetPrompt, prompt]);
const handleExitViewMode = useCallback(() => {
dispatch(viewModeChanged(false));
}, [dispatch]);
return (
<Box position="absolute" top={0} bottom={0} left={0} right={0} layerStyle="second" borderRadius="base">
<Flex
flexDir="column"
onClick={handleExitViewMode}
justifyContent="space-between"
h="full"
borderWidth={1}
borderTopWidth={24} // This prevents the prompt from being hidden behind the header
borderColor="transparent"
paddingInlineEnd={10}
paddingInlineStart={3}
paddingTop={0}
paddingBottom={3}
>
<PromptLabel label={label} />
<Flex overflow="scroll">
<Text w="full" lineHeight="short">
{presetChunks.map((chunk, index) => (
<Text
as="span"
color={index === 1 ? 'white' : 'base.200'}
fontWeight={index === 1 ? 'semibold' : 'normal'}
key={index}
>
{chunk}
</Text>
))}
</Text>
</Flex>
<Tooltip label={t('stylePresets.viewModeTooltip')}>
<Flex
position="absolute"
insetInlineEnd={0}
insetBlockStart={0}
alignItems="center"
justifyContent="center"
p={2}
bg="base.900"
opacity={0.8}
backgroundClip="border-box"
borderBottomStartRadius="base"
>
<Icon as={PiEyeBold} color="base.500" boxSize={4} />
</Flex>
</Tooltip>
</Flex>
</Box>
);
};
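ViewModePrompt highlights presetChunks[1], which implies getViewModeChunks returns roughly [prefix, userPrompt, suffix]. A rough sketch of that behavior, assuming the preset text embeds the user's prompt via a "{prompt}" placeholder (the shipped util and its tests handle more edge cases):

// Sketch only, not the shipped getViewModeChunks.
const PROMPT_PLACEHOLDER = '{prompt}';

export const getViewModeChunksSketch = (prompt: string, presetPrompt?: string): [string, string, string] => {
  if (!presetPrompt) {
    return ['', prompt, ''];
  }
  if (!presetPrompt.includes(PROMPT_PLACEHOLDER)) {
    // Assumption: with no placeholder, the preset text is simply joined after the user's prompt.
    return ['', prompt, ` ${presetPrompt}`];
  }
  const [prefix = '', suffix = ''] = presetPrompt.split(PROMPT_PLACEHOLDER);
  return [prefix, prompt, suffix];
};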

View File

@@ -1,6 +1,7 @@
import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { negativePrompt2Changed } from 'features/controlLayers/store/controlLayersSlice';
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
@@ -36,16 +37,21 @@ export const ParamSDXLNegativeStylePrompt = memo(() => {
name="prompt"
ref={textareaRef}
value={prompt}
placeholder={t('sdxl.negStylePrompt')}
onChange={onChange}
onKeyDown={onKeyDown}
fontSize="sm"
variant="darkFilled"
paddingRight={30}
minH={24}
borderTopWidth={24} // This prevents the prompt from being hidden behind the header
paddingInlineEnd={10}
paddingInlineStart={3}
paddingTop={0}
paddingBottom={3}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
</PromptOverlayButtonWrapper>
<PromptLabel label={t('sdxl.negStylePrompt')} />
</Box>
</PromptPopover>
);

Some files were not shown because too many files have changed in this diff.