Compare commits

...

223 Commits

Author SHA1 Message Date
Brandon Rising
bda579577c chore: 4.2.9 version bump 2024-09-05 16:17:48 -04:00
Brandon Rising
a16b555d47 Simplify flux model dtype conversion in model loader 2024-09-05 15:47:14 -04:00
Brandon Rising
6667c39c73 Remove dependency of asizeof 2024-09-05 15:47:14 -04:00
Brandon Rising
5219ac12a6 Add comment explaining the cache make room call 2024-09-05 15:47:14 -04:00
Brandon Rising
445f813fb9 Update flux transformer loader to more efficiently use and release memory during upcasting 2024-09-05 15:47:14 -04:00
Brandon Rising
87f9e59cfb Cast tensors in unquantized flux models to bfloat16 during loading 2024-09-05 15:47:14 -04:00
Phrixus2023
8b03b39aa8 translationBot(ui): update translation (Chinese (Simplified Han script))
Currently translated at 97.6% (1342 of 1374 strings)

Co-authored-by: Phrixus2023 <920414016@qq.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2024-09-05 15:34:13 -04:00
Tobias Lechner
e59b6bb971 translationBot(ui): update translation (German)
Currently translated at 63.3% (870 of 1374 strings)

Co-authored-by: Tobias Lechner <me@tobias-lechner.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-09-05 15:34:13 -04:00
Riccardo Giovanetti
24a7ed467c translationBot(ui): update translation (Italian)
Currently translated at 98.2% (1350 of 1374 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.2% (1350 of 1374 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.2% (1350 of 1374 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (1349 of 1370 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (1348 of 1369 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-09-05 15:34:13 -04:00
Васянатор
f01f1033ac translationBot(ui): update translation (Russian)
Currently translated at 100.0% (1370 of 1370 strings)

translationBot(ui): update translation (Russian)

Currently translated at 100.0% (1369 of 1369 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-09-05 15:34:13 -04:00
smk-e
d35f515413 translationBot(ui): update translation (Spanish)
Currently translated at 33.0% (452 of 1369 strings)

Co-authored-by: smk-e <jit-r8@outlook.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/es/
Translation: InvokeAI/Web UI
2024-09-05 15:34:13 -04:00
Brandon Rising
125b459e56 chore: 4.2.9rc2 version bump 2024-09-04 10:42:16 -04:00
Brandon Rising
33edee1ba6 Delete all flux bundle state dict keys when extracting the transformer state dict 2024-09-04 09:36:23 -04:00
Brandon Rising
d20335dabc convert_bundle_to_flux_transformer_checkpoint now removes processed keys to decrease memory usage 2024-09-04 09:36:23 -04:00
Brandon Rising
d10d258213 Add a comment for why we're converting scale tensors in flux models to bfloat16 2024-09-04 09:36:23 -04:00
Brandon
d57ba1ed8b Update invokeai/backend/model_manager/probe.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2024-09-04 09:36:23 -04:00
Brandon Rising
2d0e34e57b Support non-quantized bundles 2024-09-04 09:36:23 -04:00
Brandon Rising
a005d06255 feat: support checkpoint bundles containing more than just the transformer 2024-09-04 09:36:23 -04:00
Eugene Brodsky
a301ef5a5a chore(ci): update github action version pins in container build workflow 2024-09-03 16:01:58 -04:00
Eugene Brodsky
9422df2737 feat(ci): enable a checkbox to push the container image when manually building via workflow dispatch 2024-09-03 16:01:58 -04:00
Lincoln Stein
6dabe4d3ca assign T5 encoder to base type "Any" 2024-09-03 15:55:51 -04:00
Lincoln Stein
00e4652d30 add more reliable fallback method for determining BnbQuantizedLlmInt8b 2024-09-03 15:55:51 -04:00
Lincoln Stein
b6434c5318 correct modelformat probe for t5 encoders 2024-09-03 15:55:51 -04:00
Lincoln Stein
3f7f9f8d61 add probes for T5_encoder and ClipTextModel 2024-09-03 15:55:51 -04:00
Brandon Rising
f3bb592544 Update latents used for preview images in flux 2024-09-03 14:04:16 -04:00
Brandon Rising
69f080fb75 Move flux step callback code into the step_callback util scripts, use other services within the invocation context 2024-09-03 14:04:16 -04:00
Brandon Rising
04272a7cc8 Initial attempt at preview images 2024-09-03 14:04:16 -04:00
Lincoln Stein
8d35af946e [MM] add API routes for getting & setting MM cache sizes (#6523)
* [MM] add API routes for getting & setting MM cache sizes, and retrieving MM stats

* Update invokeai/app/api/routers/model_manager.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* code cleanup after @ryand review

* Update invokeai/app/api/routers/model_manager.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* fix merge conflicts; tested and working

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2024-09-02 12:18:21 -04:00
Ryan Dick
24065ec6b6 Add FLUX image-to-image and inpainting (#6798)
## Summary

This PR adds support for Image-to-Image and inpainting workflows with
the FLUX model.

Full changelog:
- Split out `FLUX VAE Encode` and `FLUX VAE Decode` nodes
- Renamed `FLUX Text-to-Image` node to `FLUX Denoise` (since it now
supports image-to-image too). This is a workflow-breaking change.
- Added support for FLUX image-to-image via the `Latents` param on the
FLUX denoising node.
- Added support for FLUX masked inpainting via the `Denoise Mask` param
on the FLUX denoising node.
- Added "Denoise Start" and "Denoise End" params to the "FLUX Denoise"
node.
- Updated the "FLUX Text to Image" default workflow.
- Added a "FLUX Image to Image" default workflow.

### Example

FLUX inpainting workflow
<img width="1282" alt="image"
src="https://github.com/user-attachments/assets/86fc1170-e620-4412-8fd8-e119f875fc2e">

Input image

![image](https://github.com/user-attachments/assets/9c381b86-9f87-4257-bd2e-da22c56ca26c)

Mask

![image](https://github.com/user-attachments/assets/8f774c5c-2a25-45fe-9d4b-b233e3d58d2c)

Output image

![image](https://github.com/user-attachments/assets/8576a630-24ce-4a00-8052-e86bab59c855)


### Callouts for reviewers:
- I renamed FLUXTextToImageInvocation -> FLUXDenoisingInvocation. This
is, of course, a breaking change. It feels like the right move and now
is the right time to do it. Any objection?
- I added new `FLUX VAE Encode` and `FLUX VAE Decode` nodes.
Alternatively, I could have tried to match these names to the
corresponding SD nodes (e.g. `FLUX Image to Latents`, `FLUX Latents to
Image`). Personally, I prefer the current names, but want to hear other
opinions.

### Usage notes:
- With the default dev timestep scheduler, the image structure is
largely determined in the first ~3 steps. A consequence of this is that
the denoise_start parameter provides limited 'granularity' of control.
This will likely be improved in the future as we add more scheduler
options. In the meantime, you will likely want to use small values for
`denoise_start` (e.g. 0.03) to start denoising on step ~1-4 out of ~30.
- Currently, there is no 'noise' parameter on the `FLUX Denoise` node,
so the `denoise_end` parameter has limited utility. This will be added
in the future.

## QA Instructions

Test the following workflows:
- [x] Vanilla FLUX text-to-image behaviour is unchanged
- [x] Image-to-image with FLUX dev, no mask
- [x] Image-to-image with FLUX dev, with mask
- [x] Image-to-image with FLUX schnell, no mask (smoke test, not
expected to work well)

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-09-02 09:50:31 -04:00
Ryan Dick
627b0bf644 Expose all FLUX model params in the default FLUX models. 2024-09-02 09:38:17 -04:00
Ryan Dick
b43da46b82 Rename 'FLUX VAE Encode'/'FLUX VAE Decode' to 'FLUX Image to Latents'/'FLUX Latents to Image' 2024-09-02 09:38:17 -04:00
Ryan Dick
4255a01c64 Restore line that was accidentally removed during development. 2024-09-02 09:38:17 -04:00
Ryan Dick
23adbd4002 Update schema.ts. 2024-09-02 09:38:17 -04:00
Ryan Dick
fb5a24fcc6 Update default workflows for FLUX. 2024-09-02 09:38:17 -04:00
Ryan Dick
cfdd5a1900 Rename flux_text_to_image.py -> flex_denoise.py 2024-09-02 09:38:17 -04:00
Ryan Dick
2313f326df Add denoise_end param to FluxDenoiseInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
2e092a2313 Rename FluxTextToImageInvocation -> FluxDenoiseInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
763ef06c18 Use the existence of initial latents to decide whether we are doing image-to-image in the FLUX denoising node. Previously we were using the denoising_start value, but in some cases with an inpaintin mask you may want to run image-to-image from densoising_start=0. 2024-09-02 09:38:17 -04:00
Ryan Dick
8292f6cd42 Code cleanup and documentation around FLUX inpainting. 2024-09-02 09:38:17 -04:00
Ryan Dick
278bba499e Split FLUX VAE decoding out into its own node from LatentsToImageInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
dd99ed28e0 Split FLUX VAE encoding out into its own node from ImageToLatentsInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
9a8aca69bf Get a rough version of FLUX inpainting working. 2024-09-02 09:38:17 -04:00
Ryan Dick
7ad62512eb Update MaskTensorToImageInvocation to support input mask tensors with or without a channel dimension. 2024-09-02 09:38:17 -04:00
Ryan Dick
bd466661ec Remove unused vae field from FLUXTextToImageInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
7ebb509d05 Bump FLUX node versions after splitting out VAE encode/decode. 2024-09-02 09:38:17 -04:00
Ryan Dick
0aa13c046c Split VAE decoding out from the FLUXTextToImageInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
a7a33d73f5 Get FLUX non-masked image-to-image working - still rough. 2024-09-02 09:38:17 -04:00
Ryan Dick
ffa39857d3 Add FLUX VAE decoding support to LatentsToImageInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
e85c3bc465 Add FLUX VAE support to ImageToLatentsInvocation. 2024-09-02 09:38:17 -04:00
psychedelicious
8185ba7054 scripts: add allocate_vram script
Allocates the specified amount of VRAM, or allocates enough VRAM such that you have the specified amount of VRAM free.

Useful to simulate an environment with a specific amount of VRAM.
2024-09-02 18:18:26 +10:00
Lincoln Stein
d501865bec add a new FAQ for converting safetensors (#6736)
Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-08-31 18:56:08 +00:00
Brandon Rising
d62310bb5f Support HF repos with subfolders in source on windows OS 2024-08-30 19:31:42 -04:00
Brandon Rising
1835bff196 Fix source string in hugging face installs with subfolders 2024-08-30 19:31:42 -04:00
Ryan Dick
87261bdbc9 FLUX memory management improvements (#6791)
## Summary

This PR contains several improvements to memory management for FLUX
workflows.

It is now possible to achieve better FLUX model caching performance, but
this still requires users to manually configure their `ram`/`vram`
settings. E.g. a `vram` setting of 16.0 should allow for all quantized
FLUX models to be kept in memory on the GPU.

Changes:
- Check the size of a model on disk and free the requisite space in the
model cache before loading it. (This behaviour existed previously, but
was removed in https://github.com/invoke-ai/InvokeAI/pull/6072/files.
The removal did not seem to be intentional).
- Removed the hack to free 24GB of space in the cache before loading the
FLUX model.
- Split the T5 embedding and CLIP embedding steps into separate
functions so that the two models don't both have to be held in RAM at
the same time.
- Fix a bug in `InvokeLinear8bitLt` that was causing some tensors to be
left on the GPU when the model was offloaded to the CPU. (This class is
getting very messy due to the non-standard state_dict handling in
`bnb.nn.Linear8bitLt`. )
- Tidy up some dtype handling in FluxTextToImageInvocation to avoid
situations where we hold references to two copies of the same tensor
unnecessarily.
- (minor) Misc cleanup of ModelCache: improve docs and remove unused
vars.

Future:
We should revisit our default ram/vram configs. The current defaults are
very conservative, and users could see major performance improvements
from tuning these values.

## QA Instructions

I tested the FLUX workflow with the following configurations and
verified that the cache hit rates and memory usage matched the expected
behaviour:
- `ram = 16` and `vram = 16`
- `ram = 16` and `vram = 1`
- `ram = 1` and `vram = 1`

Note that the changes in this PR are not isolated to FLUX. Since we now
check the size of models on disk, we may see slight changes in model
cache offload patterns for other models as well.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-08-29 15:17:45 -04:00
Ryan Dick
4e4b6c6dbc Tidy variable management and dtype handling in FluxTextToImageInvocation. 2024-08-29 19:08:18 +00:00
Ryan Dick
5e8cf9fb6a Remove hack to clear cache from the FluxTextToImageInvocation. We now clear the cache based on the on-disk model size. 2024-08-29 19:08:18 +00:00
Ryan Dick
c738fe051f Split T5 encoding and CLIP encoding into separate functions to ensure that all model references are locally-scoped so that the two models don't have to be help in memory at the same time. 2024-08-29 19:08:18 +00:00
Ryan Dick
29fe1533f2 Fix bug in InvokeLinear8bitLt that was causing old state information to persist after loading from a state dict. This manifested as state tensors being left on the GPU even when a model had been offloaded to the CPU cache. 2024-08-29 19:08:18 +00:00
Ryan Dick
77090070bd Check the size of a model on disk and make room for it in the cache before loading it. 2024-08-29 19:08:18 +00:00
Ryan Dick
6ba9b1b6b0 Tidy up GIG -> GB and remove unused GIG constant. 2024-08-29 19:08:18 +00:00
Ryan Dick
c578b8df1e Improve ModelCache docs. 2024-08-29 19:08:18 +00:00
Ryan Dick
cad9a41433 Remove unused MOdelCache.exists(...) function. 2024-08-29 19:08:18 +00:00
Ryan Dick
5fefb3b0f4 Remove unused param from ModelCache. 2024-08-29 19:08:18 +00:00
Ryan Dick
5284a870b0 Remove unused constructor params from ModelCache. 2024-08-29 19:08:18 +00:00
Ryan Dick
e064377c05 Remove default model cache sizes from model_cache_default.py. These defaults were misleading, because the config defaults take precedence over them. 2024-08-29 19:08:18 +00:00
Mary Hipp
3e569c8312 feat(ui): add fields for CLIP embed models and Flux VAE models in workflows 2024-08-29 11:52:51 -04:00
maryhipp
16825ee6e9 feat(nodes): bump version of flux model node, update default workflow 2024-08-29 11:52:51 -04:00
Mary Hipp
3f5340fa53 feat(nodes): add submodels as inputs to FLUX main model node instead of hardcoded names 2024-08-29 11:52:51 -04:00
chainchompa
f2a1a39b33 Add selectedStylePreset to app parameters (#6787)
## Summary
- Add selectedStylePreset to app parameters
<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-28 10:53:07 -04:00
chainchompa
326de55d3e remove api changes and only preselect style preset 2024-08-28 09:53:29 -04:00
chainchompa
b2df909570 added selectedStylePreset to preload presets when app loads 2024-08-28 09:50:44 -04:00
chainchompa
026ac36b06 Revert "added selectedStylePreset to preload presets when app loads"
This reverts commit e97fd85904.
2024-08-28 09:44:08 -04:00
chainchompa
92125e5fd2 bug fixes 2024-08-27 16:13:38 -04:00
chainchompa
c0c139da88 formatting ruff 2024-08-27 15:46:51 -04:00
chainchompa
404ad6a7fd cleanup 2024-08-27 15:42:42 -04:00
chainchompa
fc39086fb4 call stylePresetSelected 2024-08-27 15:34:31 -04:00
chainchompa
cd215700fe added route for selecting style preset 2024-08-27 15:34:07 -04:00
chainchompa
e97fd85904 added selectedStylePreset to preload presets when app loads 2024-08-27 15:33:24 -04:00
Brandon Rising
0a263fa5b1 chore: bump version to v4.2.9rc1 2024-08-27 12:09:27 -04:00
Mary Hipp
fae3836a8d fix CLIP 2024-08-27 10:29:10 -04:00
Mary Hipp
b3d2eb4178 add translations for new model types in MM, remove clip vision from filter since its not displayed in list 2024-08-27 10:29:10 -04:00
psychedelicious
576f1cbb75 build: remove broken scripts
These two scripts are broken and can cause data loss. Remove them.

They are not in the launcher script, but _are_ available to users in the terminal/file browser.

Hopefully, when we removing them here, `pip` will delete them on next installation of the package...
2024-08-27 22:01:45 +10:00
Ryan Dick
50085b40bb Update starter model size estimates. 2024-08-26 20:17:50 -04:00
Mary Hipp
cff382715a default workflow: add steps to exposed fields, add more notes 2024-08-26 20:17:50 -04:00
Brandon Rising
54d54d1bf2 Run ruff 2024-08-26 20:17:50 -04:00
Mary Hipp
e84ea68282 remove prompt 2024-08-26 20:17:50 -04:00
Mary Hipp
160dd36782 update default workflow for flux 2024-08-26 20:17:50 -04:00
Brandon Rising
65bb46bcca Rename params for flux and flux vae, add comments explaining use of the config_path in model config 2024-08-26 20:17:50 -04:00
Brandon Rising
2d185fb766 Run ruff 2024-08-26 20:17:50 -04:00
Brandon Rising
2ba9b02932 Fix type error in tsc 2024-08-26 20:17:50 -04:00
Brandon Rising
849da67cc7 Remove no longer used code in the flux denoise function 2024-08-26 20:17:50 -04:00
Brandon Rising
3ea6c9666e Remove in progress images until we're able to make the valuable 2024-08-26 20:17:50 -04:00
Brandon Rising
cf633e4ef2 Only install starter models if not already installed 2024-08-26 20:17:50 -04:00
Ryan Dick
bbf934d980 Remove outdated TODO. 2024-08-26 20:17:50 -04:00
Ryan Dick
620f733110 ruff format 2024-08-26 20:17:50 -04:00
Ryan Dick
67928609a3 Downgrade accelerate and huggingface-hub deps to original versions. 2024-08-26 20:17:50 -04:00
Ryan Dick
5f15afb7db Remove flux repo dependency 2024-08-26 20:17:50 -04:00
Ryan Dick
635d2f480d ruff 2024-08-26 20:17:50 -04:00
Brandon Rising
70c278c810 Remove dependency on flux config files 2024-08-26 20:17:50 -04:00
Brandon Rising
56b9906e2e Setup scaffolding for in progress images and add ability to cancel the flux node 2024-08-26 20:17:50 -04:00
Ryan Dick
a808ce81fd Replace swish() with torch.nn.functional.silu(h). They are functionally equivalent, but in my test VAE deconding was ~8% faster after the change. 2024-08-26 20:17:50 -04:00
Ryan Dick
83f82c5ddf Switch the CLIP-L start model to use our hosted version - which is much smaller. 2024-08-26 20:17:50 -04:00
Brandon Rising
101de8c25d Update t5 encoder formats to accurately reflect the quantization strategy and data type 2024-08-26 20:17:50 -04:00
Ryan Dick
3339a4baf0 Downgrade revert torch version after removing optimum-qanto, and other minor version-related fixes. 2024-08-26 20:17:50 -04:00
Ryan Dick
dff4a88baa Move quantization scripts to a scripts/ subdir. 2024-08-26 20:17:50 -04:00
Ryan Dick
a21f6c4964 Update docs for T5 quantization script. 2024-08-26 20:17:50 -04:00
Ryan Dick
97562504b7 Remove all references to optimum-quanto and downgrade diffusers. 2024-08-26 20:17:50 -04:00
Ryan Dick
75d8ac378c Update the T5 8-bit quantized starter model to use the BnB LLM.int8() variant. 2024-08-26 20:17:50 -04:00
Ryan Dick
b9dd354e2b Fixes to the T5XXL quantization script. 2024-08-26 20:17:50 -04:00
Ryan Dick
33c2fbd201 Add script for quantizing a T5 model. 2024-08-26 20:17:50 -04:00
Brandon Rising
5063be92bf Switch flux to using its own conditioning field 2024-08-26 20:17:50 -04:00
Brandon Rising
1047584b3e Only import bnb quantize file if bitsandbytes is installed 2024-08-26 20:17:50 -04:00
Brandon Rising
6764dcfdaa Load and unload clip/t5 encoders and run inference separately in text encoding 2024-08-26 20:17:50 -04:00
Brandon Rising
012864ceb1 Update macos test vm to macOS-14 2024-08-26 20:17:50 -04:00
Ryan Dick
a0bf20bcee Run FLUX VAE decoding in the user's preferred dtype rather than float32. Tested, and seems to work well at float16. 2024-08-26 20:17:50 -04:00
Ryan Dick
14ab339b33 Move prepare_latent_image_patches(...) to sampling.py with all of the related FLUX inference code. 2024-08-26 20:17:50 -04:00
Ryan Dick
25c91efbb6 Rename field positive_prompt -> prompt. 2024-08-26 20:17:50 -04:00
Ryan Dick
1c1f2c6664 Add comment about incorrect T5 Tokenizer size calculation. 2024-08-26 20:17:50 -04:00
Ryan Dick
d7c22b3bf7 Tidy is_schnell detection logic. 2024-08-26 20:17:50 -04:00
Ryan Dick
185f2a395f Make FLUX get_noise(...) consistent across devices/dtypes. 2024-08-26 20:17:50 -04:00
Ryan Dick
0c5649491e Mark FLUX nodes as prototypes. 2024-08-26 20:17:50 -04:00
Brandon Rising
94aba5892a Attribute black-forest-labs/flux for much of the flux code 2024-08-26 20:17:50 -04:00
Brandon Rising
ef093dde29 Don't install bitsandbytes on macOS 2024-08-26 20:17:50 -04:00
maryhipp
34451e5f27 added FLUX dev to starter models 2024-08-26 20:17:50 -04:00
Brandon Rising
1f9bdd1a9a Undo changes to the v2 dir of frontend types 2024-08-26 20:17:50 -04:00
Brandon Rising
c27d59baf7 Run ruff 2024-08-26 20:17:50 -04:00
Brandon Rising
f130ddec7c Remove automatic install of models during flux model loader, remove no longer used import function on context 2024-08-26 20:17:50 -04:00
Ryan Dick
a0a259eef1 Fix max_seq_len field description. 2024-08-26 20:17:50 -04:00
Ryan Dick
b66f19d4d1 Add docs to the quantization scripts. 2024-08-26 20:17:50 -04:00
Ryan Dick
4105a78b83 Update load_flux_model_bnb_llm_int8.py to work with a single-file FLUX transformer checkpoint. 2024-08-26 20:17:50 -04:00
Ryan Dick
19a68afb3a Fix bug in InvokeInt8Params that was causing it to use double the necessary VRAM. 2024-08-26 20:17:50 -04:00
maryhipp
fd68a2475b add better workflow name 2024-08-26 20:17:50 -04:00
maryhipp
28ff7ba830 add better workflow description 2024-08-26 20:17:50 -04:00
maryhipp
5d0b248fdb fix(worker) fix T5 type 2024-08-26 20:17:50 -04:00
maryhipp
01a4e0f6ef update default workflow 2024-08-26 20:17:50 -04:00
Mary Hipp
91e0731506 fix schema 2024-08-26 20:17:50 -04:00
Mary Hipp
d1f904d41f tsc and lint fix 2024-08-26 20:17:50 -04:00
Mary Hipp
269388c9f4 feat(ui): create new field for t5 encoder models in nodes 2024-08-26 20:17:50 -04:00
Mary Hipp
b8486379ce fix(ui): pass base/type when installing models, add flux formats to MM badges 2024-08-26 20:17:50 -04:00
Mary Hipp
400eb94d3b fix(ui): only exclude flux main models from linear UI dropdown, not model manager list 2024-08-26 20:17:50 -04:00
maryhipp
e210c96485 add FLUX schnell starter models and submodels as dependenices or adhoc download options 2024-08-26 20:17:50 -04:00
maryhipp
5f567f41f4 add case for clip embed models in probe 2024-08-26 20:17:50 -04:00
maryhipp
5fed573a29 update flux_model_loader node to take a T5 encoder from node field instead of hardcoded list, assume all models have been downloaded 2024-08-26 20:17:50 -04:00
Ryan Dick
cfac7c8189 Move requantize.py to the quatnization/ dir. 2024-08-26 20:17:50 -04:00
Ryan Dick
1787de6836 Add docs to the requantize(...) function explaining why it was copied from optimum-quanto. 2024-08-26 20:17:50 -04:00
Ryan Dick
ac96f187bd Remove duplicate log_time(...) function. 2024-08-26 20:17:50 -04:00
Brandon Rising
72398350b4 More flux loader cleanup 2024-08-26 20:17:50 -04:00
Brandon Rising
df9445c351 Various styling and exception type updates 2024-08-26 20:17:50 -04:00
Brandon Rising
87b7a2e39b Switch inheritance class of flux model loaders 2024-08-26 20:17:50 -04:00
Brandon Rising
f7e46622a1 Update doc string for import_local_model and remove access_token since it's only usable for local file paths 2024-08-26 20:17:50 -04:00
Ryan Dick
71f18353a9 Address minor review comments. 2024-08-26 20:17:50 -04:00
Ryan Dick
4228de707b Rename t5Encoder -> t5_encoder. 2024-08-26 20:17:50 -04:00
Mary Hipp
b6a05629ef add default workflow for flux t2i 2024-08-26 20:17:50 -04:00
Mary Hipp
fbaa820643 exclude flux models from main model dropdown 2024-08-26 20:17:50 -04:00
Brandon Rising
db2a2d5e38 Some cleanup of the tags and description of flux nodes 2024-08-26 20:17:50 -04:00
Brandon Rising
8ba6e6b1f8 Add t5 encoders and clip embeds to the model manager 2024-08-26 20:17:50 -04:00
Brandon Rising
57168d719b Fix styling/lint 2024-08-26 20:17:50 -04:00
Brandon Rising
dee6d2c98e Fix support for 8b quantized t5 encoders, update exception messages in flux loaders 2024-08-26 20:17:50 -04:00
Ryan Dick
e49105ece5 Add tqdm progress bar to FLUX denoising. 2024-08-26 20:17:50 -04:00
Ryan Dick
0c5e11f521 Fix FLUX output image clamping. And a few other minor fixes to make inference work with the full bfloat16 FLUX transformer model. 2024-08-26 20:17:50 -04:00
Brandon Rising
a63f842a13 Select dev/schnell based on state dict, use correct max seq len based on dev/schnell, and shift in inference, separate vae flux params into separate config 2024-08-26 20:17:50 -04:00
Brandon Rising
4bd7fda694 Install sub directories with folders correctly, ensure consistent dtype of tensors in flux pipeline and vae 2024-08-26 20:17:50 -04:00
Brandon Rising
81f0886d6f Working inference node with quantized bnb nf4 checkpoint 2024-08-26 20:17:50 -04:00
Brandon Rising
2eb87f3306 Remove unused param on _run_vae_decoding in flux text to image 2024-08-26 20:17:50 -04:00
Brandon Rising
723f3ab0a9 Add nf4 bnb quantized format 2024-08-26 20:17:50 -04:00
Brandon Rising
1bd90e0fd4 Run ruff, setup initial text to image node 2024-08-26 20:17:50 -04:00
Brandon Rising
436f18ff55 Add backend functions and classes for Flux implementation, Update the way flux encoders/tokenizers are loaded for prompt encoding, Update way flux vae is loaded 2024-08-26 20:17:50 -04:00
Brandon Rising
cde9696214 Some UI cleanup, regenerate schema 2024-08-26 20:17:50 -04:00
Brandon Rising
2d9042fb93 Run Ruff 2024-08-26 20:17:50 -04:00
Brandon Rising
9ed53af520 Run Ruff 2024-08-26 20:17:50 -04:00
Brandon Rising
56fda669fd Manage quantization of models within the loader 2024-08-26 20:17:50 -04:00
Brandon Rising
1d8545a76c Remove changes to v1 workflow 2024-08-26 20:17:50 -04:00
Brandon Rising
5f59a828f9 Setup flux model loading in the UI 2024-08-26 20:17:50 -04:00
Ryan Dick
1fa6bddc89 WIP on moving from diffusers to FLUX 2024-08-26 20:17:50 -04:00
Ryan Dick
d3a5ca5247 More improvements for LLM.int8() - not fully tested. 2024-08-26 20:17:50 -04:00
Ryan Dick
f01f56a98e LLM.int8() quantization is working, but still some rough edges to solve. 2024-08-26 20:17:50 -04:00
Ryan Dick
99b0f79784 Clean up NF4 implementation. 2024-08-26 20:17:50 -04:00
Ryan Dick
e1eb104345 NF4 inference working 2024-08-26 20:17:50 -04:00
Ryan Dick
5c2f95ef50 NF4 loading working... I think. 2024-08-26 20:17:50 -04:00
Ryan Dick
b63df9bab9 wip 2024-08-26 20:17:50 -04:00
Ryan Dick
a52c899c6d Split a FluxTextEncoderInvocation out from the FluxTextToImageInvocation. This has the advantage that we benfit from automatic caching when the prompt isn't changed. 2024-08-26 20:17:50 -04:00
Ryan Dick
eeabb7ebe5 Make quantized loading fast for both T5XXL and FLUX transformer. 2024-08-26 20:17:50 -04:00
Ryan Dick
8b1cef978c Make quantized loading fast. 2024-08-26 20:17:50 -04:00
Ryan Dick
152da482cd WIP - experimentation 2024-08-26 20:17:50 -04:00
Ryan Dick
3cf0365a35 Make float16 inference work with FLUX on 24GB GPU. 2024-08-26 20:17:50 -04:00
Ryan Dick
5870742bb9 Add support for 8-bit quantizatino of the FLUX T5XXL text encoder. 2024-08-26 20:17:50 -04:00
Ryan Dick
01d8c62c57 Make 8-bit quantization save/reload work for the FLUX transformer. Reload is still very slow with the current optimum.quanto implementation. 2024-08-26 20:17:50 -04:00
Ryan Dick
55a242b2d6 Minor improvements to FLUX workflow. 2024-08-26 20:17:50 -04:00
Ryan Dick
45263b339f Got FLUX schnell working with 8-bit quantization. Still lots of rough edges to clean up. 2024-08-26 20:17:50 -04:00
Ryan Dick
3319491861 Use the FluxPipeline.encode_prompt() api rather than trying to run the two text encoders separately. 2024-08-26 20:17:50 -04:00
Ryan Dick
e687afac90 Add sentencepiece dependency for the T5 tokenizer. 2024-08-26 20:17:50 -04:00
Ryan Dick
b39031ea53 First draft of FluxTextToImageInvocation. 2024-08-26 20:17:50 -04:00
Ryan Dick
0b77511271 Update HF download logic to work for black-forest-labs/FLUX.1-schnell. 2024-08-26 20:17:50 -04:00
Ryan Dick
c99cd989c1 Update imports for compatibility with bumped diffusers version. 2024-08-26 20:17:50 -04:00
Ryan Dick
317fdadb21 Bump diffusers version to include FLUX support. 2024-08-26 20:17:50 -04:00
Mary Hipp
4e294f9e3e disable export button if no non-default presets 2024-08-26 09:23:15 -04:00
Jonathan
526e0f30a0 Added support for bounding boxes in the Invocation API
Adding built-in bounding boxes as a core type would help developers of nodes that include bounding box support.
2024-08-26 08:03:30 +10:00
psychedelicious
231e5ec94a chore: bump version v4.2.8post1 2024-08-23 06:55:30 +10:00
Mary Hipp
e5bb6f9693 lint fix 2024-08-23 06:46:19 +10:00
Mary Hipp
da7dee44c6 fix(ui): use empty string fallback if unable to parse prompts when creating style preset from existing image 2024-08-23 06:46:19 +10:00
Eugene Brodsky
83144f4fe3 fix(docs): follow-up docker readme fixes 2024-08-22 11:19:07 -04:00
psychedelicious
c451f52ea3 chore(ui): lint 2024-08-22 21:00:09 +10:00
psychedelicious
8a2c78f2e1 fix(ui): dynamic prompts not recalculating when deleting or updating a style preset
The root cause was the active style preset not being reset when it was deleted, or no longer present in the list of style presets.

- Add extra reducer to `stylePresetSlice` to reset the active preset if it is deleted or otherwise unavailable
- Update the dynamic prompts listener to trigger on delete/update/list of style presets
2024-08-22 21:00:09 +10:00
psychedelicious
bcc78bde9b chore: bump version to v4.2.8 2024-08-22 21:00:09 +10:00
Васянатор
054bb6fe0a translationBot(ui): update translation (Russian)
Currently translated at 100.0% (1367 of 1367 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
Riccardo Giovanetti
4f4aa6d92e translationBot(ui): update translation (Italian)
Currently translated at 98.4% (1346 of 1367 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (1346 of 1367 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
Hosted Weblate
eac51ac6f5 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
psychedelicious
9f349a7c0a fix(ui): do not constrain width of hide/show boards button
lets translations display fully
2024-08-22 11:36:07 +10:00
psychedelicious
918afa5b15 fix(ui): show more of current board name 2024-08-22 11:36:07 +10:00
psychedelicious
eb1113f95c feat(ui): add translation string for "Upscale" 2024-08-22 11:36:07 +10:00
psychedelicious
4f4ba7b462 tidy(ui): clean up ActiveStylePreset markup 2024-08-21 09:06:41 +10:00
Mary Hipp
2298be0e6b fix(ui): error handling if unable to convert image URL to blob 2024-08-21 09:06:41 +10:00
Mary Hipp
63494dfca7 remove extra slash in exports path 2024-08-21 09:06:41 +10:00
Mary Hipp
36a1d39454 fix(ui): handle badge styling when template name is long 2024-08-21 09:06:41 +10:00
Mary Hipp
a6f6d5c400 fix(ui): add loading state to button when creating or updating a style preset 2024-08-21 09:06:41 +10:00
Mary Hipp
e85f221aca fix(ui): clear prompt template when prompts are recalled 2024-08-21 09:04:35 +10:00
Mary Hipp
d4797e37dc fix(ui): properly unwrap delete style preset API request so that error is caught 2024-08-19 16:12:39 -04:00
Mary Hipp
3e7923d072 fix(api): allow updating of type for style preset 2024-08-19 16:12:39 -04:00
psychedelicious
a85d69ce3d tidy(ui): getViewModeChunks.tsx -> .ts 2024-08-19 08:25:39 +10:00
psychedelicious
96db006c99 fix(ui): edge case with getViewModeChunks 2024-08-19 08:25:39 +10:00
psychedelicious
8ca57d03d8 tests(ui): add tests for getViewModeChunks 2024-08-19 08:25:39 +10:00
psychedelicious
6c404ce5f8 fix(ui): prompt template preset preview out of order 2024-08-19 08:25:39 +10:00
psychedelicious
584e07182b fix(ui): use translations for style preset strings 2024-08-17 21:27:53 +10:00
109 changed files with 5807 additions and 348 deletions

View File

@@ -13,6 +13,12 @@ on:
tags:
- 'v*.*.*'
workflow_dispatch:
inputs:
push-to-registry:
description: Push the built image to the container registry
required: false
type: boolean
default: false
permissions:
contents: write
@@ -50,16 +56,15 @@ jobs:
df -h
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v4
uses: docker/metadata-action@v5
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
images: |
ghcr.io/${{ github.repository }}
${{ env.DOCKERHUB_REPOSITORY }}
tags: |
type=ref,event=branch
type=ref,event=tag
@@ -72,49 +77,33 @@ jobs:
suffix=-${{ matrix.gpu-driver }},onlatest=false
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
with:
platforms: ${{ env.PLATFORMS }}
- name: Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
# - name: Login to Docker Hub
# if: github.event_name != 'pull_request' && vars.DOCKERHUB_REPOSITORY != ''
# uses: docker/login-action@v2
# with:
# username: ${{ secrets.DOCKERHUB_USERNAME }}
# password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build container
timeout-minutes: 40
id: docker_build
uses: docker/build-push-action@v4
uses: docker/build-push-action@v6
with:
context: .
file: docker/Dockerfile
platforms: ${{ env.PLATFORMS }}
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }}
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: |
type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
type=gha,scope=main-${{ matrix.gpu-driver }}
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
# - name: Docker Hub Description
# if: github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' && vars.DOCKERHUB_REPOSITORY != ''
# uses: peter-evans/dockerhub-description@v3
# with:
# username: ${{ secrets.DOCKERHUB_USERNAME }}
# password: ${{ secrets.DOCKERHUB_TOKEN }}
# repository: ${{ vars.DOCKERHUB_REPOSITORY }}
# short-description: ${{ github.event.repository.description }}

View File

@@ -60,7 +60,7 @@ jobs:
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-12
os: macOS-14
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022

View File

@@ -1,20 +1,22 @@
# Invoke in Docker
- Ensure that Docker can use the GPU on your system
- This documentation assumes Linux, but should work similarly under Windows with WSL2
First things first:
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
## Quickstart :lightning:
## Quickstart
No `docker compose`, no persistence, just a simple one-liner using the official images:
No `docker compose`, no persistence, single command, using the official images:
**CUDA:**
**CUDA (NVIDIA GPU):**
```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```
**ROCm:**
**ROCm (AMD GPU):**
```bash
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
@@ -22,12 +24,20 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
> [!TIP]
> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
### Data persistence
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
```bash
docker run --volume /some/local/path:/invokeai {...etc...}
```
`/some/local/path/invokeai` will contain all your data.
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
## Customize the container
We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
```bash
cd docker
@@ -38,11 +48,14 @@ cp .env.sample .env
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
>[!TIP]
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
## Docker setup in detail
#### Linux
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
3. Ensure docker daemon is able to access the GPU.
@@ -98,25 +111,7 @@ GPU_DRIVER=cuda
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
## Even More Customizing!
---
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
### Reconfigure the runtime directory
Can be used to download additional models from the supported model list
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
```yaml
command:
- invokeai-configure
- --yes
```
Or install models:
```yaml
command:
- invokeai-model-install
```
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html

View File

@@ -196,6 +196,22 @@ tips to reduce the problem:
=== "12GB VRAM GPU"
This should be sufficient to generate larger images up to about 1280x1280.
## Checkpoint Models Load Slowly or Use Too Much RAM
The difference between diffusers models (a folder containing multiple
subfolders) and checkpoint models (a file ending with .safetensors or
.ckpt) is that InvokeAI is able to load diffusers models into memory
incrementally, while checkpoint models must be loaded all at
once. With very large models, or systems with limited RAM, you may
experience slowdowns and other memory-related issues when loading
checkpoint models.
To solve this, go to the Model Manager tab (the cube), select the
checkpoint model that's giving you trouble, and press the "Convert"
button in the upper right of your browser window. This will conver the
checkpoint into a diffusers model, after which loading should be
faster and less memory-intensive.
## Memory Leak (Linux)

View File

@@ -3,8 +3,10 @@
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
from tempfile import TemporaryDirectory
from typing import List, Optional, Type
@@ -17,6 +19,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config import get_config
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
@@ -31,6 +34,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
@@ -50,6 +54,13 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True)
class CacheType(str, Enum):
"""Cache type - one of vram or ram."""
RAM = "RAM"
VRAM = "VRAM"
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
"""Add a cover image URL to a model configuration."""
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
@@ -797,3 +808,83 @@ async def get_starter_models() -> list[StarterModel]:
model.dependencies = missing_deps
return starter_models
@model_manager_router.get(
"/model_cache",
operation_id="get_cache_size",
response_model=float,
summary="Get maximum size of model manager RAM or VRAM cache.",
)
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
"""Return the current RAM or VRAM cache size setting (in GB)."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
value = 0.0
if cache_type == CacheType.RAM:
value = cache.max_cache_size
elif cache_type == CacheType.VRAM:
value = cache.max_vram_cache_size
return value
@model_manager_router.put(
"/model_cache",
operation_id="set_cache_size",
response_model=float,
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
)
async def set_cache_size(
value: float = Query(description="The new value for the maximum cache size"),
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
) -> float:
"""Set the current RAM or VRAM cache size setting (in GB). ."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
app_config = get_config()
# Record initial state.
vram_old = app_config.vram
ram_old = app_config.ram
# Prepare target state.
vram_new = vram_old
ram_new = ram_old
if cache_type == CacheType.RAM:
ram_new = value
elif cache_type == CacheType.VRAM:
vram_new = value
else:
raise ValueError(f"Unexpected {cache_type=}.")
config_path = app_config.config_file_path
new_config_path = config_path.with_suffix(".yaml.new")
try:
# Try to apply the target state.
cache.max_vram_cache_size = vram_new
cache.max_cache_size = ram_new
app_config.ram = ram_new
app_config.vram = vram_new
if persist:
app_config.write_file(new_config_path)
shutil.move(new_config_path, config_path)
except Exception as e:
# If there was a failure, restore the initial state.
cache.max_cache_size = ram_old
cache.max_vram_cache_size = vram_old
app_config.ram = ram_old
app_config.vram = vram_old
raise RuntimeError("Failed to update cache size") from e
return value
@model_manager_router.get(
"/stats",
operation_id="get_stats",
response_model=Optional[CacheStats],
summary="Get model manager RAM cache performance statistics.",
)
async def get_stats() -> Optional[CacheStats]:
"""Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats

View File

@@ -26,13 +26,10 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
)
class StylePresetUpdateFormData(BaseModel):
class StylePresetFormData(BaseModel):
name: str = Field(description="Preset name")
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
class StylePresetCreateFormData(StylePresetUpdateFormData):
type: PresetType = Field(description="Preset type")
@@ -95,9 +92,10 @@ async def update_style_preset(
try:
parsed_data = json.loads(data)
validated_data = StylePresetUpdateFormData(**parsed_data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
@@ -105,7 +103,7 @@ async def update_style_preset(
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data)
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
@@ -145,7 +143,7 @@ async def create_style_preset(
try:
parsed_data = json.loads(data)
validated_data = StylePresetCreateFormData(**parsed_data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type

View File

@@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
description=FieldDescriptions.denoise_mask,
input=Input.Connection,
ui_order=8,
)

View File

@@ -40,14 +40,18 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
FluxVAEModel = "FluxVAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
@@ -125,13 +129,17 @@ class FieldDescriptions:
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -173,7 +181,7 @@ class FieldDescriptions:
)
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
denoise_mask = "A mask of the region to apply the denoising process to."
board = "The board to save the image to"
image = "The image to process"
tile_size = "Tile size"
@@ -231,6 +239,12 @@ class ColorField(BaseModel):
return (self.r, self.g, self.b, self.a)
class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -0,0 +1,249 @@
from typing import Callable, Optional
import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule,
generate_img_ids,
get_noise,
get_schedule,
pack,
unpack,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_denoise",
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a FLUX transformer model."""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.denoise_mask,
input=Input.Connection,
)
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
# Prepare input noise.
noise = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=image_seq_len,
shift=not is_schnell,
)
# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)
# Prepare input latent image.
if init_latents is not None:
# If init_latents is provided, we are doing image-to-image.
if is_schnell:
context.logger.warning(
"Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
"to be poor. Consider using a FLUX dev model instead."
)
# Noise the orig_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
x = t_0 * noise + (1.0 - t_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
x = noise
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
# denoising steps.
if len(timesteps) <= 1:
return x
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, h, w = x.shape
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
noise = pack(noise)
x = pack(x)
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
assert image_seq_len == x.shape[1]
# Prepare inpaint extension.
inpaint_extension: InpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = InpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)
with transformer_info as transformer:
assert isinstance(transformer, Flux)
x = denoise(
model=transformer,
img=x,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
)
x = unpack(x.float(), self.height, self.width)
return x
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
- Loads the mask
- Resizes if necessary
- Casts to same device/dtype as latents
- Expands mask to the same shape as latents so that they line up after 'packing'
Args:
context (InvocationContext): The invocation context, for loading the inpaint mask.
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
device, and dtype for the inpaint mask.
Returns:
torch.Tensor | None: Inpaint mask.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
# `latents`.
return mask.expand_as(latents)
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
context.util.flux_step_callback(state)
return step_callback

View File

@@ -0,0 +1,92 @@
from typing import Literal
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@invocation(
"flux_text_encoder",
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a flux image."""
clip: CLIPField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
t5_embeddings = self._t5_encode(context)
clip_embeddings = self._clip_encode(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds

View File

@@ -0,0 +1,60 @@
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_vae_decode",
title="FLUX Latents to Image",
tags=["latents", "image", "vae", "l2i", "flux"],
category="latents",
version="1.0.0",
)
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,67 @@
import einops
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_vae_encode",
title="FLUX Image to Latents",
tags=["latents", "image", "vae", "i2l", "flux"],
category="latents",
version="1.0.0",
)
class FluxVaeEncodeInvocation(BaseInvocation):
"""Encodes an image into latents."""
image: ImageField = InputField(
description="The image to encode.",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
# TODO(ryand): Expose seed parameter at the invocation level.
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
# should be used for VAE encode sampling.
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

View File

@@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
title="Tensor Mask to Image",
tags=["mask"],
category="mask",
version="1.0.0",
version="1.1.0",
)
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Convert a mask tensor to an image."""
@@ -135,6 +135,11 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.tensors.load(self.mask.tensor_name)
# Squeeze the channel dimension if it exists.
if mask.dim() == 3:
mask = mask.squeeze(0)
# Ensure that the mask is binary.
if mask.dtype != torch.bool:
mask = mask > 0.5

View File

@@ -1,5 +1,5 @@
import copy
from typing import List, Optional
from typing import List, Literal, Optional
from pydantic import BaseModel, Field
@@ -13,7 +13,14 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
class ModelIdentifierField(BaseModel):
@@ -60,6 +67,15 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@@ -122,6 +138,78 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
@invocation(
"main_model_loader",
title="Main Model",

View File

@@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
ConditioningField,
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
ImageField,
Input,
InputField,
@@ -414,6 +415,17 @@ class MaskOutput(BaseInvocationOutput):
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("flux_conditioning_output")
class FluxConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@@ -103,7 +103,7 @@ class HFModelSource(StringLikeSource):
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
base += f"::{self.subfolder.as_posix()}"
return base

View File

@@ -783,8 +783,9 @@ class ModelInstallService(ModelInstallServiceBase):
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
path_to_remove = top / subfolder # sdxl-turbo/vae/
subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
path_to_add = Path(f"{top}_{subfolder_rename}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")

View File

@@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
format: Optional[str] = Field(description="format of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None

View File

@@ -14,7 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
@@ -557,6 +557,24 @@ class UtilInterface(InvocationContextInterface):
is_canceled=self.is_canceled,
)
def flux_step_callback(self, intermediate_state: PipelineIntermediateState) -> None:
"""
The step callback emits a progress event with the current step, the total number of
steps, a preview image, and some other internal metadata.
This should be called after each denoising step.
Args:
intermediate_state: The intermediate state of the diffusion pipeline.
"""
flux_step_callback(
context_data=self._data,
intermediate_state=intermediate_state,
events=self._services.events,
is_canceled=self.is_canceled,
)
class InvocationContext:
"""Provides access to various services and data for the current invocation.

View File

@@ -32,6 +32,7 @@ class PresetType(str, Enum, metaclass=MetaEnum):
class StylePresetChanges(BaseModel, extra="forbid"):
name: Optional[str] = Field(default=None, description="The style preset's new name.")
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
type: Optional[PresetType] = Field(description="The updated type of the style preset")
class StylePresetWithoutId(BaseModel):

View File

@@ -0,0 +1,407 @@
{
"name": "FLUX Image to Image",
"author": "InvokeAI",
"description": "A simple image-to-image workflow using a FLUX dev model. ",
"version": "1.0.4",
"contact": "",
"tags": "image2image, flux, image-to-image",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
"exposedFields": [
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "clip_embed_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "vae_model"
},
{
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
"fieldName": "denoising_start"
},
{
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"fieldName": "prompt"
},
{
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
"fieldName": "num_steps"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"nodes": [
{
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
"type": "invocation",
"data": {
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
"type": "flux_vae_encode",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"image": {
"name": "image",
"label": "",
"value": {
"image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png"
}
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 732.7680166609682,
"y": -24.37398171806909
}
},
{
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
"type": "invocation",
"data": {
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
"type": "flux_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"denoise_mask": {
"name": "denoise_mask",
"label": ""
},
"denoising_start": {
"name": "denoising_start",
"label": "",
"value": 0.04
},
"denoising_end": {
"name": "denoising_end",
"label": "",
"value": 1
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1182.8836633018684,
"y": -251.38882958913183
}
},
{
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "invocation",
"data": {
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "flux_vae_decode",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1575.5797431839133,
"y": -209.00150975507415
}
},
{
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "invocation",
"data": {
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "flux_model_loader",
"version": "1.0.4",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"model": {
"name": "model",
"label": "Model (dev variant recommended for Image-to-Image)"
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": "",
"value": {
"key": "fa23a584-b623-415d-832a-21b5098ff1a1",
"hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70",
"name": "clip-vit-large-patch14",
"base": "any",
"type": "clip_embed"
}
},
"vae_model": {
"name": "vae_model",
"label": "",
"value": {
"key": "74fc82ba-c0a8-479d-a890-2126f82da758",
"hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068",
"name": "FLUX.1-schnell_ae",
"base": "flux",
"type": "vae"
}
}
}
},
"position": {
"x": 328.1809894659957,
"y": -90.2241133566946
}
},
{
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "invocation",
"data": {
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "flux_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"clip": {
"name": "clip",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"t5_max_seq_len": {
"name": "t5_max_seq_len",
"label": "T5 Max Seq Len",
"value": 256
},
"prompt": {
"name": "prompt",
"label": "",
"value": "a cat wearing a birthday hat"
}
}
},
"position": {
"x": 745.8823365057267,
"y": -299.60249175851914
}
},
{
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "invocation",
"data": {
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 725.834098928012,
"y": 496.2710031089931
}
}
],
"edges": [
{
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height",
"type": "default",
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "height",
"targetHandle": "height"
},
{
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width",
"type": "default",
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "width",
"targetHandle": "width"
},
{
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents",
"type": "default",
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "2981a67c-480f-4237-9384-26b68dbf912b",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
"type": "default",
"source": "ace0258f-67d7-4eee-a218-6fff27065214",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
}
]
}

View File

@@ -0,0 +1,326 @@
{
"name": "FLUX Text to Image",
"author": "InvokeAI",
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
"version": "1.0.4",
"contact": "",
"tags": "text2image, flux",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"exposedFields": [
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "clip_embed_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "vae_model"
},
{
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"fieldName": "prompt"
},
{
"nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"fieldName": "num_steps"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"nodes": [
{
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"type": "invocation",
"data": {
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"type": "flux_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"denoise_mask": {
"name": "denoise_mask",
"label": ""
},
"denoising_start": {
"name": "denoising_start",
"label": "",
"value": 0
},
"denoising_end": {
"name": "denoising_end",
"label": "",
"value": 1
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1186.1868226120378,
"y": -214.9459927686657
}
},
{
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "invocation",
"data": {
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "flux_vae_decode",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1575.5797431839133,
"y": -209.00150975507415
}
},
{
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "invocation",
"data": {
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "flux_model_loader",
"version": "1.0.4",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"model": {
"name": "model",
"label": ""
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
}
}
},
"position": {
"x": 381.1882713063478,
"y": -95.89663532854017
}
},
{
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "invocation",
"data": {
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "flux_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"clip": {
"name": "clip",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"t5_max_seq_len": {
"name": "t5_max_seq_len",
"label": "T5 Max Seq Len",
"value": 256
},
"prompt": {
"name": "prompt",
"label": "",
"value": "a cat"
}
}
},
"position": {
"x": 778.4899149328337,
"y": -100.36469216659502
}
},
{
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "invocation",
"data": {
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 800.9667463219505,
"y": 285.8297267547506
}
}
],
"edges": [
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
"type": "default",
"source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
}
]
}

View File

@@ -38,6 +38,25 @@ SD1_5_LATENT_RGB_FACTORS = [
[-0.1307, -0.1874, -0.7445], # L4
]
FLUX_LATENT_RGB_FACTORS = [
[-0.0412, 0.0149, 0.0521],
[0.0056, 0.0291, 0.0768],
[0.0342, -0.0681, -0.0427],
[-0.0258, 0.0092, 0.0463],
[0.0863, 0.0784, 0.0547],
[-0.0017, 0.0402, 0.0158],
[0.0501, 0.1058, 0.1152],
[-0.0209, -0.0218, -0.0329],
[-0.0314, 0.0083, 0.0896],
[0.0851, 0.0665, -0.0472],
[-0.0534, 0.0238, -0.0024],
[0.0452, -0.0026, 0.0048],
[0.0892, 0.0831, 0.0881],
[-0.1117, -0.0304, -0.0789],
[0.0027, -0.0479, -0.0043],
[-0.1146, -0.0827, -0.0598],
]
def sample_to_lowres_estimated_image(
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
@@ -94,3 +113,32 @@ def stable_diffusion_step_callback(
intermediate_state,
ProgressImage(dataURL=dataURL, width=width, height=height),
)
def flux_step_callback(
context_data: "InvocationContextData",
intermediate_state: PipelineIntermediateState,
events: "EventServiceBase",
is_canceled: Callable[[], bool],
) -> None:
if is_canceled():
raise CanceledException
sample = intermediate_state.latents
latent_rgb_factors = torch.tensor(FLUX_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
latent_image_perm = sample.permute(1, 2, 0).to(dtype=sample.dtype, device=sample.device)
latent_image = latent_image_perm @ latent_rgb_factors
latents_ubyte = (
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF) # change scale from -1..1 to 0..1 # to 0..255
).to(device="cpu", dtype=torch.uint8)
image = Image.fromarray(latents_ubyte.cpu().numpy())
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_invocation_denoise_progress(
context_data.queue_item,
context_data.invocation,
intermediate_state,
ProgressImage(dataURL=dataURL, width=width, height=height),
)

View File

@@ -0,0 +1,56 @@
from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
def denoise(
model: Flux,
# model input
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
):
step = 0
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
preview_img = img - t_curr * pred
img = img + (t_prev - t_curr) * pred
if inpaint_extension is not None:
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
step_callback(
PipelineIntermediateState(
step=step,
order=1,
total_steps=len(timesteps),
timestep=int(t_curr),
latents=preview_img,
),
)
step += 1
return img

View File

@@ -0,0 +1,35 @@
import torch
class InpaintExtension:
"""A class for managing inpainting with FLUX."""
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
"""Initialize InpaintExtension.
Args:
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
inpainted region with the background. In 'packed' format.
noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
"""
assert init_latents.shape == inpaint_mask.shape == noise.shape
self._init_latents = init_latents
self._inpaint_mask = inpaint_mask
self._noise = noise
def merge_intermediate_latents_with_init_latents(
self, intermediate_latents: torch.Tensor, timestep: float
) -> torch.Tensor:
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
trajectory.
This function should be called after each denoising step.
"""
# Noise the init latents for the current timestep.
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)

View File

@@ -0,0 +1,32 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import torch
from einops import rearrange
from torch import Tensor
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -0,0 +1,117 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img

View File

@@ -0,0 +1,324 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: list[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = torch.nn.functional.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = torch.nn.functional.silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __init__(self, chunk_dim: int = 1):
super().__init__()
self.chunk_dim = chunk_dim
def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if sample:
std = torch.exp(0.5 * logvar)
# Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
# have to use torch.randn(...) instead.
return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
"""Run VAE encoding on input tensor x.
Args:
x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
Defaults to True.
generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
Defaults to None.
Returns:
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
"""
z = self.reg(self.encoder(x), sample=sample, generator=generator)
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))

View File

@@ -0,0 +1,33 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer
class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
super().__init__()
self.max_length = max_length
self.is_clip = is_clip
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
self.tokenizer = tokenizer
self.hf_module = encoder
self.hf_module = self.hf_module.eval().requires_grad_(False)
def forward(self, text: list[str]) -> Tensor:
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)
return outputs[self.output_key]

View File

@@ -0,0 +1,253 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from invokeai.backend.flux.math import attention, rope
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

View File

@@ -0,0 +1,135 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from typing import Callable
import torch
from einops import rearrange, repeat
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
"""Find the last index in timesteps that is >= val.
We use epsilon-close equality to avoid potential floating point errors.
"""
idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
assert idx >= 0
return idx
def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
"""Clip the timestep schedule to the denoising range.
Args:
timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
would mean that the denoising process start at the last timestep in the schedule >= 0.8.
denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
mean that the denoising process end at the last timestep in the schedule >= 0.2.
Returns:
list[float]: The clipped timestep schedule.
"""
assert 0.0 <= denoising_start <= 1.0
assert 0.0 <= denoising_end <= 1.0
assert denoising_start <= denoising_end
t_start_val = 1.0 - denoising_start
t_end_val = 1.0 - denoising_end
t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
return clipped_timesteps
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""Unpack flat array of patch embeddings to latent image."""
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
def pack(x: torch.Tensor) -> torch.Tensor:
"""Pack latent image to flattented array of patch embeddings."""
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
"""Generate tensor of image position ids.
Args:
h (int): Height of image in latent space.
w (int): Width of image in latent space.
batch_size (int): Batch size.
device (torch.device): Device.
dtype (torch.dtype): dtype.
Returns:
torch.Tensor: Image position ids.
"""
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids

View File

@@ -0,0 +1,71 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Dict, Literal
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: str | None
ae_path: str | None
repo_id: str | None
repo_flow: str | None
repo_ae: str | None
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-schnell": 256,
}
ae_params = {
"flux": AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
}
params = {
"flux-dev": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
"flux-schnell": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
}

View File

@@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
# Kandinsky2_1 = "kandinsky-2.1"
@@ -66,7 +67,9 @@ class ModelType(str, Enum):
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
CLIPEmbed = "clip_embed"
T2IAdapter = "t2i_adapter"
T5Encoder = "t5_encoder"
SpandrelImageToImage = "spandrel_image_to_image"
@@ -74,6 +77,7 @@ class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
@@ -104,6 +108,9 @@ class ModelFormat(str, Enum):
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
T5Encoder = "t5_encoder"
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
class SchedulerPredictionType(str, Enum):
@@ -186,7 +193,9 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
)
config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
@@ -205,6 +214,26 @@ class LoRAConfigBase(ModelConfigBase):
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class T5EncoderConfigBase(ModelConfigBase):
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
class T5EncoderConfig(T5EncoderConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
class LoRALyCORISConfig(LoRAConfigBase):
"""Model config for LoRA/Lycoris models."""
@@ -229,7 +258,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
@@ -268,7 +296,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
@@ -317,6 +344,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.BnbQuantizednf4b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""
@@ -350,6 +392,17 @@ class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
@@ -408,12 +461,15 @@ AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
@@ -421,6 +477,7 @@ AnyModelConfig = Annotated[
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]

View File

@@ -66,12 +66,14 @@ class ModelLoader(ModelLoaderBase):
return (model_base / config.path).resolve()
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
try:
return self._ram_cache.get(config.key, submodel_type)
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
except IndexError:
pass
config.path = str(self._get_model_path(config))
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(
@@ -83,7 +85,7 @@ class ModelLoader(ModelLoaderBase):
return self._ram_cache.get(
key=config.key,
submodel_type=submodel_type,
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
stats_name=stats_name,
)
def get_size_fs(

View File

@@ -128,7 +128,24 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def max_cache_size(self) -> float:
"""Return true if the cache is configured to lazily offload models in VRAM."""
"""Return the maximum size the RAM cache can grow to."""
pass
@max_cache_size.setter
@abstractmethod
def max_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
@property
@abstractmethod
def max_vram_cache_size(self) -> float:
"""Return the maximum size the VRAM cache can grow to."""
pass
@max_vram_cache_size.setter
@abstractmethod
def max_vram_cache_size(self, value: float) -> float:
"""Set the maximum size the VRAM cache can grow to."""
pass
@abstractmethod
@@ -193,15 +210,6 @@ class ModelCacheBase(ABC, Generic[T]):
"""
pass
@abstractmethod
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""

View File

@@ -1,22 +1,6 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
""" """
import gc
import math
@@ -40,53 +24,74 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a GB in bytes.
GB = 2**30
# Size of a MB in bytes.
MB = 2**20
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
The model cache is based on the following assumptions:
- storage_device_mem_size > execution_device_mem_size
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
the execution_device.
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
configuration.
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
the context, and unload outside the context.
Example usage:
```
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
do_something_on_gpu(SD1)
```
"""
def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
max_cache_size: float,
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
:param logger: InvokeAILogger to use (otherwise creates one)
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
@@ -128,6 +133,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Set the cap on cache size."""
self._max_cache_size = value
@property
def max_vram_cache_size(self) -> float:
"""Return the cap on vram cache size."""
return self._max_vram_cache_size
@max_vram_cache_size.setter
def max_vram_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
self._max_vram_cache_size = value
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
@@ -145,15 +160,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
total += cache_record.size
return total
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
key = self._make_cache_key(key, submodel_type)
return key in self._cached_models
def put(
self,
key: str,
@@ -203,7 +209,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
@@ -231,10 +237,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
return model_key
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
"""Offload models from the execution_device to make room for size_required.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
@@ -245,7 +254,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
)
TorchDevice.empty_cache()
@@ -303,7 +312,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
@@ -326,14 +335,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f" {(cache_entry.size/GB):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self.cache_size() / GB)
in_ram_models = 0
in_vram_models = 0
@@ -353,17 +362,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
bytes_needed = size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
current_size = self.cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
@@ -380,7 +392,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1

View File

@@ -0,0 +1,246 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import accelerate
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.util.model_util import (
convert_bundle_to_flux_transformer_checkpoint,
)
from invokeai.backend.util.silence_warnings import SilenceWarnings
try:
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
bnb_available = True
except ImportError:
bnb_available = False
app_config = get_config()
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class FluxVAELoader(ModelLoader):
"""Class to load VAE models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, VAECheckpointConfig):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
with SilenceWarnings():
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CLIPEmbedDiffusersConfig):
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer:
return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
case SubModelType.TextEncoder:
return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
state_dict = load_file(state_dict_path)
self._load_state_dict_into_t5(model, state_dict)
return model
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@classmethod
def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderConfig):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
class FluxCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
with SilenceWarnings():
model = Flux(params[config.config_path])
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
self._ram_cache.make_room(new_sd_size)
for k in sd.keys():
# We need to cast to bfloat16 due to it being the only currently supported dtype for inference
sd[k] = sd[k].to(torch.bfloat16)
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
model_path = Path(config.path)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
model.load_state_dict(sd, assign=True)
return model

View File

@@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
if module in ["diffusers", "transformers"]:
if module in [
"diffusers",
"transformers",
"invokeai.backend.quantization.fast_quantized_transformers_model",
"invokeai.backend.quantization.fast_quantized_diffusion_model",
]:
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines

View File

@@ -36,8 +36,18 @@ VARIANT_TO_IN_CHANNEL_MAP = {
}
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
)
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""

View File

@@ -9,7 +9,7 @@ from typing import Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@@ -50,6 +50,17 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
),
):
return model.calc_size()
elif isinstance(
model,
(
T5TokenizerFast,
T5Tokenizer,
),
):
# HACK(ryand): len(model) just returns the vocabulary size, so this is blatantly wrong. It should be small
# relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
# point.
return len(model)
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
# supported model types.

View File

@@ -95,6 +95,7 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
@@ -106,6 +107,9 @@ class ModelProbe(object):
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
"CLIPModel": ModelType.CLIPEmbed,
"CLIPTextModel": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
}
@classmethod
@@ -161,7 +165,7 @@ class ModelProbe(object):
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = fields.get("default_settings")
@@ -176,10 +180,10 @@ class ModelProbe(object):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models
if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint
):
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
ModelFormat.Checkpoint,
ModelFormat.BnbQuantizednf4b,
]:
ckpt_config_path = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
@@ -222,7 +226,19 @@ class ModelProbe(object):
ckpt = ckpt.get("state_dict", ckpt)
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
if key.startswith(
(
"cond_stage_model.",
"first_stage_model.",
"model.diffusion_model.",
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix.
"double_blocks.",
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
# This prefix is typically used to distinguish between multiple models bundled in a single file.
"model.diffusion_model.double_blocks.",
)
):
# Keys starting with double_blocks are associated with Flux models
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
@@ -280,9 +296,16 @@ class ModelProbe(object):
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json"
c = folder_path / "config.json"
config_path = i if i.exists() else c if c.exists() else None
config_path = None
for p in [
folder_path / "model_index.json", # pipeline
folder_path / "config.json", # most diffusers
folder_path / "text_encoder_2" / "config.json", # T5 text encoder
folder_path / "text_encoder" / "config.json", # T5 CLIP
]:
if p.exists():
config_path = p
break
if config_path:
with open(config_path, "r") as file:
@@ -321,10 +344,30 @@ class ModelProbe(object):
return possible_conf.absolute()
if model_type is ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
if base_type == BaseModelType.Flux:
# TODO: Decide between dev/schnell
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if (
"guidance_in.out_layer.weight" in state_dict
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
):
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-dev"
else:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the discriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-schnell"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
elif model_type is ModelType.ControlNet:
config_file = (
"controlnet/cldm_v15.yaml"
@@ -333,7 +376,13 @@ class ModelProbe(object):
)
elif model_type is ModelType.VAE:
config_file = (
"stable-diffusion/v1-inference.yaml"
# For flux, this is a key in invokeai.backend.flux.util.ae_params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
"flux"
if base_type is BaseModelType.Flux
else "stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "stable-diffusion/sd_xl_base.yaml"
if base_type is BaseModelType.StableDiffusionXL
@@ -416,11 +465,18 @@ class CheckpointProbeBase(ProbeBase):
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
def get_format(self) -> ModelFormat:
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if (
"double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
):
return ModelFormat.BnbQuantizednf4b
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
if model_type != ModelType.Main:
base_type = self.get_base_type()
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
@@ -440,6 +496,11 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
if (
"double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
):
return BaseModelType.Flux
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
@@ -482,6 +543,7 @@ class VaeCheckpointProbe(CheckpointProbeBase):
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
@@ -713,6 +775,30 @@ class TextualInversionFolderProbe(FolderProbeBase):
return TextualInversionCheckpointProbe(path).get_base_type()
class T5EncoderFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
def get_format(self) -> ModelFormat:
path = self.model_path / "text_encoder_2"
if (path / "model.safetensors.index.json").exists():
return ModelFormat.T5Encoder
files = list(path.glob("*.safetensors"))
if len(files) == 0:
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: no .safetensors files found")
# shortcut: look for the quantization in the name
if any(x for x in files if "llm_int8" in x.as_posix()):
return ModelFormat.BnbQuantizedLlmInt8b
# more reliable path: probe contents for a 'SCB' key
ckpt = read_checkpoint_meta(files[0], scan=True)
if any("SCB" in x for x in ckpt.keys()):
return ModelFormat.BnbQuantizedLlmInt8b
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: unknown model format")
class ONNXFolderProbe(PipelineFolderProbe):
def get_base_type(self) -> BaseModelType:
# Due to the way the installer is set up, the configuration file for safetensors
@@ -805,6 +891,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
class CLIPEmbedFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class SpandrelImageToImageFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
@@ -835,8 +926,10 @@ ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)

View File

@@ -2,7 +2,7 @@ from typing import Optional
from pydantic import BaseModel
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
class StarterModelWithoutDependencies(BaseModel):
@@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel):
name: str
base: BaseModelType
type: ModelType
format: Optional[ModelFormat] = None
is_installed: bool = False
@@ -51,10 +52,76 @@ cyberrealistic_negative = StarterModel(
type=ModelType.TextualInversion,
)
t5_base_encoder = StarterModel(
name="t5_base_encoder",
base=BaseModelType.Any,
source="InvokeAI/t5-v1_1-xxl::bfloat16",
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
type=ModelType.T5Encoder,
)
t5_8b_quantized_encoder = StarterModel(
name="t5_bnb_int8_quantized_encoder",
base=BaseModelType.Any,
source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
type=ModelType.T5Encoder,
format=ModelFormat.BnbQuantizedLlmInt8b,
)
clip_l_encoder = StarterModel(
name="clip-vit-large-patch14",
base=BaseModelType.Any,
source="InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16",
description="CLIP-L text encoder (used in FLUX pipelines). ~250MB",
type=ModelType.CLIPEmbed,
)
flux_vae = StarterModel(
name="FLUX.1-schnell_ae",
base=BaseModelType.Flux,
source="black-forest-labs/FLUX.1-schnell::ae.safetensors",
description="FLUX VAE compatible with both schnell and dev variants.",
type=ModelType.VAE,
)
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
# region: Main
StarterModel(
name="FLUX Schnell (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Dev (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Schnell",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Dev",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="CyberRealistic v4.1",
base=BaseModelType.StableDiffusion1,
@@ -125,6 +192,7 @@ STARTER_MODELS: list[StarterModel] = [
# endregion
# region VAE
sdxl_fp16_vae_fix,
flux_vae,
# endregion
# region LoRA
StarterModel(
@@ -450,6 +518,11 @@ STARTER_MODELS: list[StarterModel] = [
type=ModelType.SpandrelImageToImage,
),
# endregion
# region TextEncoders
t5_base_encoder,
t5_8b_quantized_encoder,
clip_l_encoder,
# endregion
]
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"

View File

@@ -133,3 +133,29 @@ def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[in
break
return lora_token_vector_length
def convert_bundle_to_flux_transformer_checkpoint(
transformer_state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
original_state_dict: dict[str, torch.Tensor] = {}
keys_to_remove: list[str] = []
for k, v in transformer_state_dict.items():
if not k.startswith("model.diffusion_model"):
keys_to_remove.append(k) # This can be removed in the future if we only want to delete transformer keys
continue
if k.endswith("scale"):
# Scale math must be done at bfloat16 due to our current flux model
# support limitations at inference time
v = v.to(dtype=torch.bfloat16)
new_key = k.replace("model.diffusion_model.", "")
original_state_dict[new_key] = v
keys_to_remove.append(k)
# Remove processed keys from the original dictionary, leaving others in case
# other model state dicts need to be pulled
for k in keys_to_remove:
del transformer_state_dict[k]
return original_state_dict

View File

@@ -54,6 +54,7 @@ def filter_files(
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
)
):
paths.append(file)
@@ -62,13 +63,13 @@ def filter_files(
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
paths.append(file)
# limit search to subfolder if requested
if subfolder:
subfolder = root / subfolder
paths = [x for x in paths if x.parent == Path(subfolder)]
paths = [x for x in paths if Path(subfolder) in x.parents]
# _filter_by_variant uniquifies the paths and returns a set
return sorted(_filter_by_variant(paths, variant))
@@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
if variant == ModelRepoVariant.Flax:
result.add(path)
elif path.suffix in [".json", ".txt"]:
# Note: '.model' was added to support:
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
elif path.suffix in [".json", ".txt", ".model"]:
result.add(path)
elif variant in [
@@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
continue
for candidate_list in subfolder_weights.values():
# Check if at least one of the files has the explicit fp16 variant.
at_least_one_fp16 = False
for candidate in candidate_list:
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
at_least_one_fp16 = True
break
if not at_least_one_fp16:
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
# we'll simply keep all the candidates. An example of a model that hits this case is
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
for candidate in candidate_list:
result.add(candidate.path)
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
# candidate.
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)

View File

@@ -0,0 +1,135 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeInt8Params(bnb.nn.Int8Params):
"""We override cuda() to avoid re-quantizing the weights in the following cases:
- We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
- We are moving the model back-and-forth between the cpu and gpu.
"""
def cuda(self, device):
if self.has_fp16_weights:
return super().cuda(device)
elif self.CB is not None and self.SCB is not None:
self.data = self.data.cuda()
self.CB = self.data
self.SCB = self.SCB.cuda()
else:
# we store the 8-bit rows-major weight
# we convert this weight to the turning/ampere weight during the first inference pass
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
self.data = CB
self.CB = CB
self.SCB = SCB
return self
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None)
# Currently, we only support weight_format=0.
weight_format = state_dict.pop(prefix + "weight_format", None)
assert weight_format == 0
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert len(state_dict) == 0
if scb is not None:
# We are loading a pre-quantized state dict.
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
# Note: After quantization, CB is the same as weight.
CB=weight,
SCB=scb,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
CB=None,
SCB=None,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
# Reset the state. The persisted fields are based on the initialization behaviour in
# `bnb.nn.Linear8bitLt.__init__()`.
new_state = bnb.MatmulLtState()
new_state.threshold = self.state.threshold
new_state.has_fp16_weights = False
new_state.use_pool = self.state.use_pool
self.state = new_state
def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
) -> None:
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinear8bitLt(
child.in_features,
child.out_features,
bias=has_bias,
has_fp16_weights=False,
threshold=outlier_threshold,
)
replacement.weight.data = child.weight.data
if has_bias:
replacement.bias.data = child.bias.data
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_llm_8bit(
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
)
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
_convert_linear_layers_to_llm_8bit(
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
)
return model

View File

@@ -0,0 +1,156 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes NF4 quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeLinearNF4(bnb.nn.LinearNF4):
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
- Ability to load Linear NF4 layers from a pre-quantized state_dict.
- Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
"""
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
"""
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# We expect the remaining keys to be quant_state keys.
quant_state_sd = state_dict
# During serialization, the quant_state is stored as subkeys of "weight." (See
# `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
if len(quant_state_sd) > 0:
# We are loading a pre-quantized state dict.
self.weight = bnb.nn.Params4bit.from_prequantized(
data=weight, quantized_stats=quant_state_sd, device=weight.device
)
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = bnb.nn.Params4bit(
data=weight,
requires_grad=self.weight.requires_grad,
compress_statistics=self.weight.compress_statistics,
quant_type=self.weight.quant_type,
quant_storage=self.weight.quant_storage,
module=self,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
def _replace_param(
param: torch.nn.Parameter | bnb.nn.Params4bit,
data: torch.Tensor,
) -> torch.nn.Parameter:
"""A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
the "meta" device.
Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
"""
if param.device.type == "meta":
# Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
# re-create the param instead of overwriting the data.
if isinstance(param, bnb.nn.Params4bit):
return bnb.nn.Params4bit(
data,
requires_grad=data.requires_grad,
quant_state=param.quant_state,
compress_statistics=param.compress_statistics,
quant_type=param.quant_type,
)
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
param.data = data
return param
def _convert_linear_layers_to_nf4(
module: torch.nn.Module,
ignore_modules: set[str],
compute_dtype: torch.dtype,
compress_statistics: bool = False,
prefix: str = "",
) -> None:
"""Convert all linear layers in the model to NF4 quantized linear layers.
Args:
module: All linear layers in this module will be converted.
ignore_modules: A set of module prefixes to ignore when converting linear layers.
compute_dtype: The dtype to use for computation in the quantized linear layers.
compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
constants from the first quantization are quantized again.
prefix: The prefix of the current module in the model. Used to call this function recursively.
"""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinearNF4(
child.in_features,
child.out_features,
bias=has_bias,
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
)
if has_bias:
replacement.bias = _replace_param(replacement.bias, child.bias.data)
replacement.weight = _replace_param(replacement.weight, child.weight.data)
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)
def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
"""Apply bitsandbytes nf4 quantization to the model.
You likely want to call this function inside a `accelerate.init_empty_weights()` context.
Example usage:
```
# Initialize the model from a config on the meta device.
with accelerate.init_empty_weights():
model = ModelClass.from_config(...)
# Add NF4 quantization linear layers to the model - still on the meta device.
with accelerate.init_empty_weights():
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)
# Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
model.load_state_dict(state_dict, strict=True, assign=True)
# Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
# place.
model.to("cuda")
```
"""
_convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)
return model

View File

@@ -0,0 +1,79 @@
from pathlib import Path
import accelerate
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
def main():
"""A script for quantizing a FLUX transformer model using the bitsandbytes LLM.int8() quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
# Load the FLUX transformer model onto the meta device.
model_path = Path(
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path.parent / "bnb_llm_int8.safetensors"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path)
model.load_state_dict(sd, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
state_dict = load_file(model_path)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_int8_path)
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
assert isinstance(model, Flux)
return model
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,96 @@
import time
from contextlib import contextmanager
from pathlib import Path
import accelerate
import torch
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@contextmanager
def log_time(name: str):
"""Helper context manager to log the time taken by a block of code."""
start = time.time()
try:
yield None
finally:
end = time.time()
print(f"'{name}' took {end - start:.4f} secs")
def main():
"""A script for quantizing a FLUX transformer model using the bitsandbytes NF4 quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
model_path = Path(
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
# inference_dtype = torch.bfloat16
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
if model_nf4_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4(
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
)
with log_time("Load state dict into model"):
state_dict = load_file(model_nf4_path)
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4(
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
)
with log_time("Load state dict into model"):
state_dict = load_file(model_path)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_nf4_path)
print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
assert isinstance(model, Flux)
return model
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,92 @@
from pathlib import Path
import accelerate
from safetensors.torch import load_file, save_file
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
def main():
"""A script for quantizing a T5 text encoder model using the bitsandbytes LLM.int8() quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
model_path = Path("/data/misc/text_encoder_2")
with log_time("Intialize T5 on meta device"):
model_config = AutoConfig.from_pretrained(model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path / "bnb_llm_int8.safetensors"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path)
load_state_dict_into_t5(model, sd)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
# Load sharded state dict.
files = list(model_path.glob("*.safetensors"))
state_dict = {}
for file in files:
sd = load_file(file)
state_dict.update(sd)
load_state_dict_into_t5(model, state_dict)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
state_dict = model.state_dict()
state_dict.pop("encoder.embed_tokens.weight")
save_file(state_dict, model_int8_path)
# This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
# over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
# save_model(model, model_int8_path)
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
assert isinstance(model, T5EncoderModel)
return model
if __name__ == "__main__":
main()

View File

@@ -25,11 +25,6 @@ class BasicConditioningInfo:
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo]
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
@@ -43,6 +38,22 @@ class SDXLConditioningInfo(BasicConditioningInfo):
return super().to(device=device, dtype=dtype)
@dataclass
class FLUXConditioningInfo:
clip_embeds: torch.Tensor
t5_embeds: torch.Tensor
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
@dataclass
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor

View File

@@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
"""
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import GIG, Chdir, directory_size
from invokeai.backend.util.util import Chdir, directory_size
__all__ = [
"GIG",
"directory_size",
"Chdir",
"InvokeAILogger",

View File

@@ -7,9 +7,6 @@ from pathlib import Path
from PIL import Image
# actual size of a gig
GIG = 1073741824
def slugify(value: str, allow_unicode: bool = False) -> str:
"""

View File

@@ -127,7 +127,14 @@
"bulkDownloadRequestedDesc": "Dein Download wird vorbereitet. Dies kann ein paar Momente dauern.",
"bulkDownloadRequestFailed": "Problem beim Download vorbereiten",
"bulkDownloadFailed": "Download fehlgeschlagen",
"alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen"
"alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen",
"selectForCompare": "Zum Vergleichen auswählen",
"compareImage": "Bilder vergleichen",
"exitSearch": "Suche beenden",
"newestFirst": "Neueste zuerst",
"oldestFirst": "Älteste zuerst",
"openInViewer": "Im Viewer öffnen",
"swapImages": "Bilder tauschen"
},
"hotkeys": {
"keyboardShortcuts": "Tastenkürzel",
@@ -631,7 +638,8 @@
"archived": "Archiviert",
"noBoards": "Kein {boardType}} Ordner",
"hideBoards": "Ordner verstecken",
"viewBoards": "Ordner ansehen"
"viewBoards": "Ordner ansehen",
"deletedPrivateBoardsCannotbeRestored": "Gelöschte Boards können nicht wiederhergestellt werden. Wenn Sie „Nur Board löschen“ wählen, werden die Bilder in einen privaten, nicht kategorisierten Status für den Ersteller des Bildes versetzt."
},
"controlnet": {
"showAdvanced": "Zeige Erweitert",
@@ -781,7 +789,9 @@
"batchFieldValues": "Stapelverarbeitungswerte",
"batchQueued": "Stapelverarbeitung eingereiht",
"graphQueued": "Graph eingereiht",
"graphFailedToQueue": "Fehler beim Einreihen des Graphen"
"graphFailedToQueue": "Fehler beim Einreihen des Graphen",
"generations_one": "Generation",
"generations_other": "Generationen"
},
"metadata": {
"negativePrompt": "Negativ Beschreibung",
@@ -1146,5 +1156,10 @@
"noMatchingTriggers": "Keine passenden Trigger",
"addPromptTrigger": "Prompt-Trigger hinzufügen",
"compatibleEmbeddings": "Kompatible Einbettungen"
},
"ui": {
"tabs": {
"queue": "Warteschlange"
}
}
}

View File

@@ -696,6 +696,8 @@
"availableModels": "Available Models",
"baseModel": "Base Model",
"cancel": "Cancel",
"clipEmbed": "CLIP Embed",
"clipVision": "CLIP Vision",
"config": "Config",
"convert": "Convert",
"convertingModelBegin": "Converting Model. Please wait.",
@@ -783,13 +785,16 @@
"settings": "Settings",
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source",
"spandrelImageToImage": "Image to Image (Spandrel)",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",
"loraTriggerPhrases": "LoRA Trigger Phrases",
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
"typePhraseHere": "Type phrase here",
"t5Encoder": "T5 Encoder",
"upcastAttention": "Upcast Attention",
"uploadImage": "Upload Image",
"urlOrLocalPath": "URL or Local Path",
@@ -1675,6 +1680,7 @@
"layers_other": "Layers"
},
"upscaling": {
"upscale": "Upscale",
"creativity": "Creativity",
"exceedsMaxSize": "Upscale settings exceed max size limit",
"exceedsMaxSizeDetails": "Max upscale limit is {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixels. Please try a smaller image or decrease your scale selection.",
@@ -1723,6 +1729,7 @@
"positivePrompt": "Positive Prompt",
"preview": "Preview",
"private": "Private",
"promptTemplateCleared": "Prompt Template Cleared",
"searchByName": "Search by name",
"shared": "Shared",
"sharedTemplates": "Shared Templates",

View File

@@ -86,15 +86,15 @@
"loadMore": "Cargar más",
"noImagesInGallery": "No hay imágenes para mostrar",
"deleteImage_one": "Eliminar Imagen",
"deleteImage_many": "",
"deleteImage_other": "",
"deleteImage_many": "Eliminar {{count}} Imágenes",
"deleteImage_other": "Eliminar {{count}} Imágenes",
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
"assets": "Activos",
"autoAssignBoardOnClick": "Asignación automática de tableros al hacer clic"
},
"hotkeys": {
"keyboardShortcuts": "Atajos de teclado",
"appHotkeys": "Atajos de applicación",
"appHotkeys": "Atajos de aplicación",
"generalHotkeys": "Atajos generales",
"galleryHotkeys": "Atajos de galería",
"unifiedCanvasHotkeys": "Atajos de lienzo unificado",
@@ -535,7 +535,7 @@
"bottomMessage": "Al eliminar este panel y las imágenes que contiene, se restablecerán las funciones que los estén utilizando actualmente.",
"deleteBoardAndImages": "Borrar el panel y las imágenes",
"loading": "Cargando...",
"deletedBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar",
"deletedBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar. Al Seleccionar 'Borrar Solo el Panel' transferirá las imágenes a un estado sin categorizar.",
"move": "Mover",
"menuItemAutoAdd": "Agregar automáticamente a este panel",
"searchBoard": "Buscando paneles…",
@@ -549,7 +549,13 @@
"imagesWithCount_other": "{{count}} imágenes",
"assetsWithCount_one": "{{count}} activo",
"assetsWithCount_many": "{{count}} activos",
"assetsWithCount_other": "{{count}} activos"
"assetsWithCount_other": "{{count}} activos",
"hideBoards": "Ocultar Paneles",
"addPrivateBoard": "Agregar un tablero privado",
"addSharedBoard": "Agregar Panel Compartido",
"boards": "Paneles",
"archiveBoard": "Archivar Panel",
"archived": "Archivado"
},
"accordions": {
"compositing": {

View File

@@ -496,7 +496,9 @@
"main": "Principali",
"noModelsInstalledDesc1": "Installa i modelli con",
"ipAdapters": "Adattatori IP",
"noMatchingModels": "Nessun modello corrispondente"
"noMatchingModels": "Nessun modello corrispondente",
"starterModelsInModelManager": "I modelli iniziali possono essere trovati in Gestione Modelli",
"spandrelImageToImage": "Immagine a immagine (Spandrel)"
},
"parameters": {
"images": "Immagini",
@@ -510,7 +512,7 @@
"perlinNoise": "Rumore Perlin",
"type": "Tipo",
"strength": "Forza",
"upscaling": "Ampliamento",
"upscaling": "Amplia",
"scale": "Scala",
"imageFit": "Adatta l'immagine iniziale alle dimensioni di output",
"scaleBeforeProcessing": "Scala prima dell'elaborazione",
@@ -593,7 +595,7 @@
"globalPositivePromptPlaceholder": "Prompt positivo globale",
"globalNegativePromptPlaceholder": "Prompt negativo globale",
"processImage": "Elabora Immagine",
"sendToUpscale": "Invia a Ampliare",
"sendToUpscale": "Invia a Amplia",
"postProcessing": "Post-elaborazione (Shift + U)"
},
"settings": {
@@ -929,7 +931,7 @@
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino ai valori predefiniti",
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
},
@@ -1420,7 +1422,7 @@
"paramUpscaleMethod": {
"heading": "Metodo di ampliamento",
"paragraphs": [
"Metodo utilizzato per eseguire l'ampliamento dell'immagine per la correzione ad alta risoluzione."
"Metodo utilizzato per ampliare l'immagine per la correzione ad alta risoluzione."
]
},
"patchmatchDownScaleSize": {
@@ -1528,7 +1530,7 @@
},
"upscaleModel": {
"paragraphs": [
"Il modello di ampliamento (Upscale), scala l'immagine alle dimensioni di uscita prima di aggiungere i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
"Il modello di ampliamento, scala l'immagine alle dimensioni di uscita prima di aggiungere i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
],
"heading": "Modello di ampliamento"
},
@@ -1720,26 +1722,27 @@
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Coda",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"upscaling": "Ampliamento",
"upscaling": "Amplia",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
}
},
"upscaling": {
"creativity": "Creatività",
"structure": "Struttura",
"upscaleModel": "Modello di Ampliamento",
"upscaleModel": "Modello di ampliamento",
"scale": "Scala",
"missingModelsWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare i modelli richiesti:",
"mainModelDesc": "Modello principale (architettura SD1.5 o SDXL)",
"tileControlNetModelDesc": "Modello Tile ControlNet per l'architettura del modello principale scelto",
"upscaleModelDesc": "Modello per l'ampliamento (da immagine a immagine)",
"upscaleModelDesc": "Modello per l'ampliamento (immagine a immagine)",
"missingUpscaleInitialImage": "Immagine iniziale mancante per l'ampliamento",
"missingUpscaleModel": "Modello per lampliamento mancante",
"missingTileControlNetModel": "Nessun modello ControlNet Tile valido installato",
"postProcessingModel": "Modello di post-elaborazione",
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine).",
"exceedsMaxSize": "Le impostazioni di ampliamento superano il limite massimo delle dimensioni",
"exceedsMaxSizeDetails": "Il limite massimo di ampliamento è {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixel. Prova un'immagine più piccola o diminuisci la scala selezionata."
"exceedsMaxSizeDetails": "Il limite massimo di ampliamento è {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixel. Prova un'immagine più piccola o diminuisci la scala selezionata.",
"upscale": "Amplia"
},
"upsell": {
"inviteTeammates": "Invita collaboratori",
@@ -1782,7 +1785,14 @@
"updatePromptTemplate": "Aggiorna il modello di prompt",
"type": "Tipo",
"promptTemplatesDesc2": "Utilizza la stringa segnaposto <Pre>{{placeholder}}</Pre> per specificare dove inserire il tuo prompt nel modello.",
"importTemplates": "Importa modelli di prompt",
"importTemplatesDesc": "Il formato deve essere un CSV con colonne 'name' e 'prompt' o 'positive_prompt' e 'negative_prompt' incluse, oppure un file JSON con chiavi 'name' e 'prompt' o 'positive_prompt' e 'negative_prompt"
"importTemplates": "Importa modelli di prompt (CSV/JSON)",
"exportDownloaded": "Esportazione completata",
"exportFailed": "Impossibile generare e scaricare il file CSV",
"exportPromptTemplates": "Esporta i miei modelli di prompt (CSV)",
"positivePromptColumn": "'prompt' o 'positive_prompt'",
"noTemplates": "Nessun modello",
"acceptedColumnsKeys": "Colonne/chiavi accettate:",
"templateActions": "Azioni modello",
"promptTemplateCleared": "Modello di prompt cancellato"
}
}

View File

@@ -91,7 +91,8 @@
"enabled": "Включено",
"disabled": "Отключено",
"comparingDesc": "Сравнение двух изображений",
"comparing": "Сравнение"
"comparing": "Сравнение",
"dontShowMeThese": "Не показывай мне это"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@@ -153,7 +154,11 @@
"showArchivedBoards": "Показать архивированные доски",
"searchImages": "Поиск по метаданным",
"displayBoardSearch": "Отобразить поиск досок",
"displaySearch": "Отобразить поиск"
"displaySearch": "Отобразить поиск",
"exitBoardSearch": "Выйти из поиска досок",
"go": "Перейти",
"exitSearch": "Выйти из поиска",
"jump": "Пыгнуть"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@@ -376,6 +381,10 @@
"toggleViewer": {
"title": "Переключить просмотр изображений",
"desc": "Переключение между средством просмотра изображений и рабочей областью для текущей вкладки."
},
"postProcess": {
"desc": "Обработайте текущее изображение с помощью выбранной модели постобработки",
"title": "Обработать изображение"
}
},
"modelManager": {
@@ -492,7 +501,8 @@
"noModelsInstalled": "Нет установленных моделей",
"noModelsInstalledDesc1": "Установите модели с помощью",
"noMatchingModels": "Нет подходящих моделей",
"ipAdapters": "IP адаптеры"
"ipAdapters": "IP адаптеры",
"starterModelsInModelManager": "Стартовые модели можно найти в Менеджере моделей"
},
"parameters": {
"images": "Изображения",
@@ -589,7 +599,10 @@
"infillColorValue": "Цвет заливки",
"globalSettings": "Глобальные настройки",
"globalNegativePromptPlaceholder": "Глобальный негативный запрос",
"globalPositivePromptPlaceholder": "Глобальный запрос"
"globalPositivePromptPlaceholder": "Глобальный запрос",
"postProcessing": "Постобработка (Shift + U)",
"processImage": "Обработка изображения",
"sendToUpscale": "Отправить на увеличение"
},
"settings": {
"models": "Модели",
@@ -623,7 +636,9 @@
"intermediatesCleared_many": "Очищено {{count}} промежуточных",
"clearIntermediatesDesc1": "Очистка промежуточных элементов приведет к сбросу состояния Canvas и ControlNet.",
"intermediatesClearedFailed": "Проблема очистки промежуточных",
"reloadingIn": "Перезагрузка через"
"reloadingIn": "Перезагрузка через",
"informationalPopoversDisabled": "Информационные всплывающие окна отключены",
"informationalPopoversDisabledDesc": "Информационные всплывающие окна были отключены. Включите их в Настройках."
},
"toast": {
"uploadFailed": "Загрузка не удалась",
@@ -694,7 +709,9 @@
"sessionRef": "Сессия: {{sessionId}}",
"outOfMemoryError": "Ошибка нехватки памяти",
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
"somethingWentWrong": "Что-то пошло не так"
"somethingWentWrong": "Что-то пошло не так",
"importFailed": "Импорт неудачен",
"importSuccessful": "Импорт успешен"
},
"tooltip": {
"feature": {
@@ -1017,7 +1034,8 @@
"composition": "Только композиция",
"hed": "HED",
"beginEndStepPercentShort": "Начало/конец %",
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)"
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)",
"depthAnythingSmallV2": "Small V2"
},
"boards": {
"autoAddBoard": "Авто добавление Доски",
@@ -1042,7 +1060,7 @@
"downloadBoard": "Скачать доску",
"deleteBoard": "Удалить доску",
"deleteBoardAndImages": "Удалить доску и изображения",
"deletedBoardsCannotbeRestored": "Удаленные доски не подлежат восстановлению",
"deletedBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в состояние без категории.",
"assetsWithCount_one": "{{count}} ассет",
"assetsWithCount_few": "{{count}} ассета",
"assetsWithCount_many": "{{count}} ассетов",
@@ -1057,7 +1075,11 @@
"boards": "Доски",
"addPrivateBoard": "Добавить личную доску",
"private": "Личные доски",
"shared": "Общие доски"
"shared": "Общие доски",
"hideBoards": "Скрыть доски",
"viewBoards": "Просмотреть доски",
"noBoards": "Нет досок {{boardType}}",
"deletedPrivateBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в приватное состояние без категории для создателя изображения."
},
"dynamicPrompts": {
"seedBehaviour": {
@@ -1417,6 +1439,30 @@
"paragraphs": [
"Метод, с помощью которого применяется текущий IP-адаптер."
]
},
"structure": {
"paragraphs": [
"Структура контролирует, насколько точно выходное изображение будет соответствовать макету оригинала. Низкая структура допускает значительные изменения, в то время как высокая структура строго сохраняет исходную композицию и макет."
],
"heading": "Структура"
},
"scale": {
"paragraphs": [
"Масштаб управляет размером выходного изображения и основывается на кратном разрешении входного изображения. Например, при увеличении в 2 раза изображения 1024x1024 на выходе получится 2048 x 2048."
],
"heading": "Масштаб"
},
"creativity": {
"paragraphs": [
"Креативность контролирует степень свободы, предоставляемой модели при добавлении деталей. При низкой креативности модель остается близкой к оригинальному изображению, в то время как высокая креативность позволяет вносить больше изменений. При использовании подсказки высокая креативность увеличивает влияние подсказки."
],
"heading": "Креативность"
},
"upscaleModel": {
"heading": "Модель увеличения",
"paragraphs": [
"Модель увеличения масштаба масштабирует изображение до выходного размера перед добавлением деталей. Можно использовать любую поддерживаемую модель масштабирования, но некоторые из них специализированы для различных видов изображений, например фотографий или линейных рисунков."
]
}
},
"metadata": {
@@ -1693,7 +1739,80 @@
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Очередь"
"queue": "Очередь",
"upscaling": "Увеличение",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
}
},
"upscaling": {
"exceedsMaxSize": "Параметры масштабирования превышают максимальный размер",
"exceedsMaxSizeDetails": "Максимальный предел масштабирования составляет {{maxUpscaleDimension}}x{{maxUpscaleDimension}} пикселей. Пожалуйста, попробуйте использовать меньшее изображение или уменьшите масштаб.",
"structure": "Структура",
"missingTileControlNetModel": "Не установлены подходящие модели ControlNet",
"missingUpscaleInitialImage": "Отсутствует увеличиваемое изображение",
"missingUpscaleModel": "Отсутствует увеличивающая модель",
"creativity": "Креативность",
"upscaleModel": "Модель увеличения",
"scale": "Масштаб",
"mainModelDesc": "Основная модель (архитектура SD1.5 или SDXL)",
"upscaleModelDesc": "Модель увеличения (img2img)",
"postProcessingModel": "Модель постобработки",
"tileControlNetModelDesc": "Модель ControlNet для выбранной архитектуры основной модели",
"missingModelsWarning": "Зайдите в <LinkComponent>Менеджер моделей</LinkComponent> чтоб установить необходимые модели:",
"postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img).",
"upscale": "Увеличить"
},
"stylePresets": {
"noMatchingTemplates": "Нет подходящих шаблонов",
"promptTemplatesDesc1": "Шаблоны подсказок добавляют текст к подсказкам, которые вы пишете в окне подсказок.",
"sharedTemplates": "Общие шаблоны",
"templateDeleted": "Шаблон запроса удален",
"toggleViewMode": "Переключить режим просмотра",
"type": "Тип",
"unableToDeleteTemplate": "Не получилось удалить шаблон запроса",
"viewModeTooltip": "Вот как будет выглядеть ваш запрос с выбранным шаблоном. Чтобы его отредактировать, щелкните в любом месте текстового поля.",
"viewList": "Просмотреть список шаблонов",
"active": "Активно",
"choosePromptTemplate": "Выберите шаблон запроса",
"defaultTemplates": "Стандартные шаблоны",
"deleteImage": "Удалить изображение",
"deleteTemplate": "Удалить шаблон",
"deleteTemplate2": "Вы уверены, что хотите удалить этот шаблон? Это нельзя отменить.",
"editTemplate": "Редактировать шаблон",
"exportPromptTemplates": "Экспорт моих шаблонов запроса (CSV)",
"exportDownloaded": "Экспорт скачан",
"exportFailed": "Невозможно сгенерировать и загрузить CSV",
"flatten": "Объединить выбранный шаблон с текущим запросом",
"acceptedColumnsKeys": "Принимаемые столбцы/ключи:",
"positivePromptColumn": "'prompt' или 'positive_prompt'",
"insertPlaceholder": "Вставить заполнитель",
"name": "Имя",
"negativePrompt": "Негативный запрос",
"promptTemplatesDesc3": "Если вы не используете заполнитель, шаблон будет добавлен в конец запроса.",
"positivePrompt": "Позитивный запрос",
"preview": "Предпросмотр",
"private": "Приватный",
"templateActions": "Действия с шаблоном",
"updatePromptTemplate": "Обновить шаблон запроса",
"uploadImage": "Загрузить изображение",
"useForTemplate": "Использовать для шаблона запроса",
"clearTemplateSelection": "Очистить выбор шаблона",
"copyTemplate": "Копировать шаблон",
"createPromptTemplate": "Создать шаблон запроса",
"importTemplates": "Импортировать шаблоны запроса (CSV/JSON)",
"nameColumn": "'name'",
"negativePromptColumn": "'negative_prompt'",
"myTemplates": "Мои шаблоны",
"noTemplates": "Нет шаблонов",
"promptTemplatesDesc2": "Используйте строку-заполнитель <Pre>{{placeholder}}</Pre>, чтобы указать место, куда должен быть включен ваш запрос в шаблоне.",
"searchByName": "Поиск по имени",
"shared": "Общий",
"promptTemplateCleared": "Шаблон запроса создан"
},
"upsell": {
"inviteTeammates": "Пригласите членов команды",
"professional": "Профессионал",
"professionalUpsell": "Доступно в профессиональной версии Invoke. Нажмите здесь или посетите invoke.com/pricing для получения более подробной информации.",
"shareAccess": "Поделиться доступом"
}
}

View File

@@ -154,7 +154,8 @@
"displaySearch": "显示搜索",
"stretchToFit": "拉伸以适应",
"exitCompare": "退出对比",
"compareHelp1": "在点击图库中的图片或使用箭头键切换比较图片时,请按住<Kbd>Alt</Kbd> 键。"
"compareHelp1": "在点击图库中的图片或使用箭头键切换比较图片时,请按住<Kbd>Alt</Kbd> 键。",
"go": "运行"
},
"hotkeys": {
"keyboardShortcuts": "快捷键",
@@ -494,7 +495,9 @@
"huggingFacePlaceholder": "所有者或模型名称",
"huggingFaceRepoID": "HuggingFace仓库ID",
"loraTriggerPhrases": "LoRA 触发词",
"ipAdapters": "IP适配器"
"ipAdapters": "IP适配器",
"spandrelImageToImage": "图生图(Spandrel)",
"starterModelsInModelManager": "您可以在模型管理器中找到初始模型"
},
"parameters": {
"images": "图像",
@@ -695,7 +698,9 @@
"outOfMemoryErrorDesc": "您当前的生成设置已超出系统处理能力.请调整设置后再次尝试.",
"parametersSet": "参数已恢复",
"errorCopied": "错误信息已复制",
"modelImportCanceled": "模型导入已取消"
"modelImportCanceled": "模型导入已取消",
"importFailed": "导入失败",
"importSuccessful": "导入成功"
},
"unifiedCanvas": {
"layer": "图层",
@@ -1705,12 +1710,55 @@
"missingModelsWarning": "请访问<LinkComponent>模型管理器</LinkComponent> 安装所需的模型:",
"mainModelDesc": "主模型SD1.5或SDXL架构",
"exceedsMaxSize": "放大设置超出了最大尺寸限制",
"exceedsMaxSizeDetails": "最大放大限制是 {{maxUpscaleDimension}}x{{maxUpscaleDimension}} 像素.请尝试一个较小的图像或减少您的缩放选择."
"exceedsMaxSizeDetails": "最大放大限制是 {{maxUpscaleDimension}}x{{maxUpscaleDimension}} 像素.请尝试一个较小的图像或减少您的缩放选择.",
"upscale": "放大"
},
"upsell": {
"inviteTeammates": "邀请团队成员",
"professional": "专业",
"professionalUpsell": "可在 Invoke 的专业版中使用.点击此处或访问 invoke.com/pricing 了解更多详情.",
"shareAccess": "共享访问权限"
},
"stylePresets": {
"positivePrompt": "正向提示词",
"preview": "预览",
"deleteImage": "删除图像",
"deleteTemplate": "删除模版",
"deleteTemplate2": "您确定要删除这个模板吗?请注意,删除后无法恢复.",
"importTemplates": "导入提示模板支持CSV或JSON格式",
"insertPlaceholder": "插入一个占位符",
"myTemplates": "我的模版",
"name": "名称",
"type": "类型",
"unableToDeleteTemplate": "无法删除提示模板",
"updatePromptTemplate": "更新提示词模版",
"exportPromptTemplates": "导出我的提示模板为CSV格式",
"exportDownloaded": "导出已下载",
"noMatchingTemplates": "无匹配的模版",
"promptTemplatesDesc1": "提示模板可以帮助您在编写提示时添加预设的文本内容.",
"promptTemplatesDesc3": "如果您没有使用占位符,那么模板的内容将会被添加到您提示的末尾.",
"searchByName": "按名称搜索",
"shared": "已分享",
"sharedTemplates": "已分享的模版",
"templateActions": "模版操作",
"templateDeleted": "提示模版已删除",
"toggleViewMode": "切换显示模式",
"uploadImage": "上传图像",
"active": "激活",
"choosePromptTemplate": "选择提示词模板",
"clearTemplateSelection": "清除模版选择",
"copyTemplate": "拷贝模版",
"createPromptTemplate": "创建提示词模版",
"defaultTemplates": "默认模版",
"editTemplate": "编辑模版",
"exportFailed": "无法生成并下载CSV文件",
"flatten": "将选定的模板内容合并到当前提示中",
"negativePrompt": "反向提示词",
"promptTemplateCleared": "提示模板已清除",
"useForTemplate": "用于提示词模版",
"viewList": "预览模版列表",
"viewModeTooltip": "这是您的提示在当前选定的模板下的预览效果。如需编辑提示,请直接在文本框中点击进行修改.",
"noTemplates": "无模版",
"private": "私密"
}
}

View File

@@ -14,6 +14,7 @@ import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageMo
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import InvokeTabs from 'features/ui/components/InvokeTabs';
@@ -39,10 +40,17 @@ interface Props {
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
selectedStylePresetId?: string;
destination?: InvokeTabName | undefined;
}
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
const App = ({
config = DEFAULT_CONFIG,
selectedImage,
selectedWorkflowId,
selectedStylePresetId,
destination,
}: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger('system');
const dispatch = useAppDispatch();
@@ -81,6 +89,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
}
}, [selectedWorkflowId, getAndLoadWorkflow]);
useEffect(() => {
if (selectedStylePresetId) {
dispatch(activeStylePresetIdChanged(selectedStylePresetId));
}
}, [dispatch, selectedStylePresetId]);
useEffect(() => {
if (destination) {
dispatch(setActiveTab(destination));

View File

@@ -45,6 +45,7 @@ interface Props extends PropsWithChildren {
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
selectedStylePresetId?: string;
destination?: InvokeTabName;
customStarUi?: CustomStarUi;
socketOptions?: Partial<ManagerOptions & SocketOptions>;
@@ -66,6 +67,7 @@ const InvokeAIUI = ({
queueId,
selectedImage,
selectedWorkflowId,
selectedStylePresetId,
destination,
customStarUi,
socketOptions,
@@ -227,6 +229,7 @@ const InvokeAIUI = ({
config={config}
selectedImage={selectedImage}
selectedWorkflowId={selectedWorkflowId}
selectedStylePresetId={selectedStylePresetId}
destination={destination}
/>
</AppDndContext>

View File

@@ -13,6 +13,7 @@ import {
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { socketConnected } from 'services/events/actions';
@@ -22,7 +23,10 @@ const matcher = isAnyOf(
maxPromptsChanged,
maxPromptsReset,
socketConnected,
activeStylePresetIdChanged
activeStylePresetIdChanged,
stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled,
stylePresetsApi.endpoints.updateStylePreset.matchFulfilled,
stylePresetsApi.endpoints.listStylePresets.matchFulfilled
);
export const addDynamicPromptsListener = (startAppListening: AppStartListening) => {

View File

@@ -8,7 +8,7 @@ import { $authToken } from 'app/store/nanostores/authToken';
*/
export const convertImageUrlToBlob = async (url: string) =>
new Promise<Blob | null>((resolve) => {
new Promise<Blob | null>((resolve, reject) => {
const img = new Image();
img.onload = () => {
const canvas = document.createElement('canvas');
@@ -17,17 +17,23 @@ export const convertImageUrlToBlob = async (url: string) =>
const context = canvas.getContext('2d');
if (!context) {
reject(new Error('Failed to get canvas context'));
return;
}
context.drawImage(img, 0, 0);
resolve(
new Promise<Blob | null>((resolve) => {
canvas.toBlob(function (blob) {
resolve(blob);
}, 'image/png');
})
);
canvas.toBlob((blob) => {
if (blob) {
resolve(blob);
} else {
reject(new Error('Failed to convert image to blob'));
}
}, 'image/png');
};
img.onerror = () => {
reject(new Error('Image failed to load. The URL may be invalid or the object may not exist.'));
};
img.crossOrigin = $authToken.get() ? 'use-credentials' : 'anonymous';
img.src = url;
});

View File

@@ -66,7 +66,7 @@ export const Gallery = () => {
<Flex flexDirection="column" alignItems="center" justifyContent="space-between" h="full" w="full" pt={1}>
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full">
<TabList gap={2} fontSize="sm" borderColor="base.800" alignItems="center" w="full">
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2">
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2" wordBreak="break-all">
{boardName}
</Text>
<Spacer />

View File

@@ -60,7 +60,6 @@ const ImageGalleryContent = () => {
<GalleryHeader />
<Flex alignItems="center" justifyContent="space-between" w="full">
<Button
w={112}
size="sm"
variant="ghost"
onClick={handleToggleBoardPanel}

View File

@@ -1,14 +1,20 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { handlers, parseAndRecallAllMetadata, parseAndRecallPrompts } from 'features/metadata/util/handlers';
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
export const useImageActions = (image_name?: string) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const activeTabName = useAppSelector(activeTabNameSelector);
const activeStylePresetId = useAppSelector((s) => s.stylePreset.activeStylePresetId);
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(image_name);
const [hasMetadata, setHasMetadata] = useState(false);
const [hasSeed, setHasSeed] = useState(false);
@@ -46,14 +52,26 @@ export const useImageActions = (image_name?: string) => {
parseMetadata();
}, [metadata]);
const clearStylePreset = useCallback(() => {
if (activeStylePresetId) {
dispatch(activeStylePresetIdChanged(null));
toast({
status: 'info',
title: t('stylePresets.promptTemplateCleared'),
});
}
}, [dispatch, activeStylePresetId, t]);
const recallAll = useCallback(() => {
parseAndRecallAllMetadata(metadata, activeTabName === 'generation');
}, [activeTabName, metadata]);
clearStylePreset();
}, [activeTabName, metadata, clearStylePreset]);
const remix = useCallback(() => {
// Recalls all metadata parameters except seed
parseAndRecallAllMetadata(metadata, activeTabName === 'generation', ['seed']);
}, [activeTabName, metadata]);
clearStylePreset();
}, [activeTabName, metadata, clearStylePreset]);
const recallSeed = useCallback(() => {
handlers.seed.parse(metadata).then((seed) => {
@@ -63,12 +81,24 @@ export const useImageActions = (image_name?: string) => {
const recallPrompts = useCallback(() => {
parseAndRecallPrompts(metadata);
}, [metadata]);
clearStylePreset();
}, [metadata, clearStylePreset]);
const createAsPreset = useCallback(async () => {
if (image_name && metadata && imageDTO) {
const positivePrompt = await handlers.positivePrompt.parse(metadata);
const negativePrompt = await handlers.negativePrompt.parse(metadata);
let positivePrompt;
let negativePrompt;
try {
positivePrompt = await handlers.positivePrompt.parse(metadata);
} catch (error) {
positivePrompt = '';
}
try {
negativePrompt = await handlers.negativePrompt.parse(metadata);
} catch (error) {
negativePrompt = '';
}
$stylePresetModalState.set({
prefilledFormData: {

View File

@@ -5,17 +5,33 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type Props = {
result: GetStarterModelsResponse[number];
modelList: AnyModelConfig[];
};
export const StarterModelsResultItem = memo(({ result }: Props) => {
export const StarterModelsResultItem = memo(({ result, modelList }: Props) => {
const { t } = useTranslation();
const allSources = useMemo(() => {
const _allSources = [{ source: result.source, config: { name: result.name, description: result.description } }];
const _allSources = [
{
source: result.source,
config: {
name: result.name,
description: result.description,
type: result.type,
base: result.base,
format: result.format,
},
},
];
if (result.dependencies) {
for (const d of result.dependencies) {
_allSources.push({ source: d.source, config: { name: d.name, description: d.description } });
_allSources.push({
source: d.source,
config: { name: d.name, description: d.description, type: d.type, base: d.base, format: d.format },
});
}
}
return _allSources;
@@ -24,9 +40,12 @@ export const StarterModelsResultItem = memo(({ result }: Props) => {
const onClick = useCallback(() => {
for (const { config, source } of allSources) {
if (modelList.some((mc) => config.base === mc.base && config.name === mc.name && config.type === mc.type)) {
continue;
}
installModel({ config, source });
}
}, [allSources, installModel]);
}, [modelList, allSources, installModel]);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>

View File

@@ -1,17 +1,31 @@
import { Flex } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
import { memo } from 'react';
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
import { memo, useMemo } from 'react';
import {
modelConfigsAdapterSelectors,
useGetModelConfigsQuery,
useGetStarterModelsQuery,
} from 'services/api/endpoints/models';
import { StarterModelsResults } from './StarterModelsResults';
export const StarterModelsForm = memo(() => {
const { isLoading, data } = useGetStarterModelsQuery();
const { data: modelListRes } = useGetModelConfigsQuery();
const modelList = useMemo(() => {
if (!modelListRes) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(modelListRes);
}, [modelListRes]);
return (
<Flex flexDir="column" height="100%" gap={3}>
{isLoading && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
{data && <StarterModelsResults results={data} />}
{data && <StarterModelsResults results={data} modelList={modelList} />}
</Flex>
);
});

View File

@@ -5,14 +5,16 @@ import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { StarterModelsResultItem } from './StartModelsResultItem';
type StarterModelsResultsProps = {
results: NonNullable<GetStarterModelsResponse>;
modelList: AnyModelConfig[];
};
export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps) => {
export const StarterModelsResults = memo(({ results, modelList }: StarterModelsResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
@@ -72,7 +74,7 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps
<ScrollableContent>
<Flex flexDir="column" gap={3}>
{filteredResults.map((result) => (
<StarterModelsResultItem key={result.source} result={result} />
<StarterModelsResultItem key={result.source} result={result} modelList={modelList} />
))}
</Flex>
</ScrollableContent>

View File

@@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
'sd-2': 'teal',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',
};
const ModelBaseBadge = ({ base }: Props) => {

View File

@@ -13,6 +13,9 @@ const FORMAT_NAME_MAP: Record<AnyModelConfig['format'], string> = {
invokeai: 'internal',
embedding_file: 'embedding',
embedding_folder: 'embedding',
t5_encoder: 't5_encoder',
bnb_quantized_int8b: 'bnb_quantized_int8b',
bnb_quantized_nf4b: 'quantized',
};
const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
@@ -22,6 +25,9 @@ const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
invokeai: 'base',
embedding_file: 'base',
embedding_folder: 'base',
t5_encoder: 'base',
bnb_quantized_int8b: 'base',
bnb_quantized_nf4b: 'base',
};
const ModelFormatBadge = ({ format }: Props) => {

View File

@@ -5,6 +5,7 @@ import type { FilterableModelType } from 'features/modelManagerV2/store/modelMan
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
useClipEmbedModels,
useControlNetModels,
useEmbeddingModels,
useIPAdapterModels,
@@ -13,6 +14,7 @@ import {
useRefinerModels,
useSpandrelImageToImageModels,
useT2IAdapterModels,
useT5EncoderModels,
useVAEModels,
} from 'services/api/hooks/modelsByType';
import type { AnyModelConfig } from 'services/api/types';
@@ -73,6 +75,18 @@ const ModelList = () => {
[vaeModels, searchTerm, filteredModelType]
);
const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels();
const filteredT5EncoderModels = useMemo(
() => modelsFilter(t5EncoderModels, searchTerm, filteredModelType),
[t5EncoderModels, searchTerm, filteredModelType]
);
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useClipEmbedModels();
const filteredClipEmbedModels = useMemo(
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
[clipEmbedModels, searchTerm, filteredModelType]
);
const [spandrelImageToImageModels, { isLoading: isLoadingSpandrelImageToImageModels }] =
useSpandrelImageToImageModels();
const filteredSpandrelImageToImageModels = useMemo(
@@ -90,7 +104,9 @@ const ModelList = () => {
filteredT2IAdapterModels.length +
filteredIPAdapterModels.length +
filteredVAEModels.length +
filteredSpandrelImageToImageModels.length
filteredSpandrelImageToImageModels.length +
t5EncoderModels.length +
clipEmbedModels.length
);
}, [
filteredControlNetModels.length,
@@ -102,6 +118,8 @@ const ModelList = () => {
filteredT2IAdapterModels.length,
filteredVAEModels.length,
filteredSpandrelImageToImageModels.length,
t5EncoderModels.length,
clipEmbedModels.length,
]);
return (
@@ -154,13 +172,23 @@ const ModelList = () => {
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
)}
{/* T5 Encoders List */}
{isLoadingT5EncoderModels && <FetchingModelsLoader loadingMessage="Loading T5 Encoder Models..." />}
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
)}
{/* Clip Embed List */}
{isLoadingClipEmbedModels && <FetchingModelsLoader loadingMessage="Loading Clip Embed Models..." />}
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
<ModelListWrapper title={t('modelManager.clipEmbed')} modelList={filteredClipEmbedModels} key="clip-embed" />
)}
{/* Spandrel Image to Image List */}
{isLoadingSpandrelImageToImageModels && (
<FetchingModelsLoader loadingMessage="Loading Image-to-Image Models..." />
)}
{!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && (
<ModelListWrapper
title="Image-to-Image"
title={t('modelManager.spandrelImageToImage')}
modelList={filteredSpandrelImageToImageModels}
key="spandrel-image-to-image"
/>

View File

@@ -19,9 +19,10 @@ export const ModelTypeFilter = memo(() => {
controlnet: 'ControlNet',
vae: 'VAE',
t2i_adapter: t('common.t2iAdapter'),
t5_encoder: t('modelManager.t5Encoder'),
clip_embed: t('modelManager.clipEmbed'),
ip_adapter: t('common.ipAdapter'),
clip_vision: 'Clip Vision',
spandrel_image_to_image: 'Image-to-Image',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
}),
[t]
);

View File

@@ -6,6 +6,8 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlNetModelFieldInputInstance,
@@ -14,6 +16,10 @@ import {
isEnumFieldInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isFluxVAEModelFieldInputInstance,
isFluxVAEModelFieldInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldInputInstance,
@@ -38,6 +44,8 @@ import {
isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
isT5EncoderModelFieldInputTemplate,
isVAEModelFieldInputInstance,
isVAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
@@ -45,9 +53,12 @@ import { memo } from 'react';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@@ -59,6 +70,7 @@ import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputCo
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = {
@@ -113,6 +125,17 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
@@ -145,6 +168,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;

View File

@@ -0,0 +1,60 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
import type { ClipEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useClipEmbedModels();
const _onChange = useCallback(
(value: ClipEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(CLIPEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,55 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
const FluxMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(FluxMainModelFieldInputComponent);

View File

@@ -0,0 +1,60 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
const FluxVAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxVAEModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(FluxVAEModelFieldInputComponent);

View File

@@ -0,0 +1,60 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice';
import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate>;
const T5EncoderModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
(value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldT5EncoderValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(T5EncoderModelFieldInputComponent);

View File

@@ -6,11 +6,13 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
BooleanFieldValue,
CLIPEmbedModelFieldValue,
ColorFieldValue,
ControlNetModelFieldValue,
EnumFieldValue,
FieldValue,
FloatFieldValue,
FluxVAEModelFieldValue,
ImageFieldValue,
IntegerFieldValue,
IPAdapterModelFieldValue,
@@ -23,15 +25,18 @@ import type {
StatefulFieldValue,
StringFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
} from 'features/nodes/types/field';
import {
zBoardFieldValue,
zBooleanFieldValue,
zCLIPEmbedModelFieldValue,
zColorFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
zFloatFieldValue,
zFluxVAEModelFieldValue,
zImageFieldValue,
zIntegerFieldValue,
zIPAdapterModelFieldValue,
@@ -44,6 +49,7 @@ import {
zStatefulFieldValue,
zStringFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
@@ -341,6 +347,15 @@ export const nodesSlice = createSlice({
) => {
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
},
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
},
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
},
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
@@ -402,6 +417,9 @@ export const {
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
fieldFluxVAEModelValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
@@ -514,6 +532,9 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
fieldFluxVAEModelValueChanged,
nodesChanged,
nodeIsIntermediateChanged,
nodeIsOpenChanged,

View File

@@ -61,7 +61,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
const zModelType = z.enum([
'main',
'vae',
@@ -73,9 +73,12 @@ const zModelType = z.enum([
'onnx',
'clip_vision',
'spandrel_image_to_image',
't5_encoder',
'clip_embed',
]);
const zSubModelType = z.enum([
'unet',
'transformer',
'text_encoder',
'text_encoder_2',
'tokenizer',

View File

@@ -31,6 +31,7 @@ export const MODEL_TYPES = [
'ControlNetModelField',
'LoRAModelField',
'MainModelField',
'FluxMainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
@@ -38,6 +39,7 @@ export const MODEL_TYPES = [
'VAEField',
'CLIPField',
'T2IAdapterModelField',
'T5EncoderField',
'SpandrelImageToImageModelField',
];
@@ -50,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
CLIPField: 'green.500',
ColorField: 'pink.300',
ConditioningField: 'cyan.500',
FluxConditioningField: 'cyan.500',
ControlField: 'teal.500',
ControlNetModelField: 'teal.500',
EnumField: 'blue.500',
@@ -61,6 +64,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LatentsField: 'pink.500',
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500',
@@ -68,6 +72,8 @@ export const FIELD_COLORS: { [key: string]: string } = {
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
UNetField: 'red.500',
T5EncoderField: 'green.500',
TransformerField: 'red.500',
VAEField: 'blue.500',
VAEModelField: 'teal.500',
};

View File

@@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
@@ -143,6 +147,18 @@ const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
name: z.literal('SpandrelImageToImageModelField'),
originalType: zStatelessFieldType.optional(),
});
const zT5EncoderModelFieldType = zFieldTypeBase.extend({
name: z.literal('T5EncoderModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
@@ -158,6 +174,7 @@ const zStatefulFieldType = z.union([
zModelIdentifierFieldType,
zMainModelFieldType,
zSDXLMainModelFieldType,
zFluxMainModelFieldType,
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
@@ -165,6 +182,9 @@ const zStatefulFieldType = z.union([
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
zSpandrelImageToImageModelFieldType,
zT5EncoderModelFieldType,
zCLIPEmbedModelFieldType,
zFluxVAEModelFieldType,
zColorFieldType,
zSchedulerFieldType,
]);
@@ -447,6 +467,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region FluxMainModelField
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxMainModelFieldType,
originalType: zFieldType.optional(),
default: zFluxMainModelFieldValue,
});
const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFluxMainModelFieldType,
});
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
zFluxMainModelFieldInputInstance.safeParse(val).success;
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
zFluxMainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SDXLRefinerModelField
/** @alias */ // tells knip to ignore this duplicate export
@@ -613,6 +656,75 @@ export const isSpandrelImageToImageModelFieldInputTemplate = (
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region T5EncoderModelField
export const zT5EncoderModelFieldValue = zModelIdentifierField.optional();
const zT5EncoderModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zT5EncoderModelFieldValue,
});
const zT5EncoderModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zT5EncoderModelFieldType,
originalType: zFieldType.optional(),
default: zT5EncoderModelFieldValue,
});
export type T5EncoderModelFieldValue = z.infer<typeof zT5EncoderModelFieldValue>;
export type T5EncoderModelFieldInputInstance = z.infer<typeof zT5EncoderModelFieldInputInstance>;
export type T5EncoderModelFieldInputTemplate = z.infer<typeof zT5EncoderModelFieldInputTemplate>;
export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5EncoderModelFieldInputInstance =>
zT5EncoderModelFieldInputInstance.safeParse(val).success;
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region FluxVAEModelField
export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxVAEModelFieldValue,
});
const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxVAEModelFieldType,
originalType: zFieldType.optional(),
default: zFluxVAEModelFieldValue,
});
export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
zFluxVAEModelFieldInputInstance.safeParse(val).success;
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
zFluxVAEModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region CLIPEmbedModelField
export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPEmbedModelFieldValue,
});
const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPEmbedModelFieldValue,
});
export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
@@ -693,6 +805,7 @@ export const zStatefulFieldValue = z.union([
zModelIdentifierFieldValue,
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
@@ -700,6 +813,9 @@ export const zStatefulFieldValue = z.union([
zIPAdapterModelFieldValue,
zT2IAdapterModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zT5EncoderModelFieldValue,
zFluxVAEModelFieldValue,
zCLIPEmbedModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@@ -720,6 +836,7 @@ const zStatefulFieldInputInstance = z.union([
zBoardFieldInputInstance,
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
@@ -728,6 +845,9 @@ const zStatefulFieldInputInstance = z.union([
zIPAdapterModelFieldInputInstance,
zT2IAdapterModelFieldInputInstance,
zSpandrelImageToImageModelFieldInputInstance,
zT5EncoderModelFieldInputInstance,
zFluxVAEModelFieldInputInstance,
zCLIPEmbedModelFieldInputInstance,
zColorFieldInputInstance,
zSchedulerFieldInputInstance,
]);
@@ -749,6 +869,7 @@ const zStatefulFieldInputTemplate = z.union([
zBoardFieldInputTemplate,
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
@@ -757,6 +878,9 @@ const zStatefulFieldInputTemplate = z.union([
zIPAdapterModelFieldInputTemplate,
zT2IAdapterModelFieldInputTemplate,
zSpandrelImageToImageModelFieldInputTemplate,
zT5EncoderModelFieldInputTemplate,
zFluxVAEModelFieldInputTemplate,
zCLIPEmbedModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,
@@ -779,6 +903,7 @@ const zStatefulFieldOutputTemplate = z.union([
zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,

View File

@@ -15,12 +15,16 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
MainModelField: undefined,
SchedulerField: 'euler',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,
SpandrelImageToImageModelField: undefined,
VAEModelField: undefined,
ControlNetModelField: undefined,
T5EncoderModelField: undefined,
FluxVAEModelField: undefined,
CLIPEmbedModelField: undefined,
};
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

View File

@@ -2,12 +2,15 @@ import { FieldParseError } from 'features/nodes/types/error';
import type {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
CLIPEmbedModelFieldInputTemplate,
ColorFieldInputTemplate,
ControlNetModelFieldInputTemplate,
EnumFieldInputTemplate,
FieldInputTemplate,
FieldType,
FloatFieldInputTemplate,
FluxMainModelFieldInputTemplate,
FluxVAEModelFieldInputTemplate,
ImageFieldInputTemplate,
IntegerFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
@@ -22,6 +25,7 @@ import type {
StatelessFieldInputTemplate,
StringFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { isStatefulFieldType } from 'features/nodes/types/field';
@@ -180,6 +184,20 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLMainMo
return template;
};
const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxMainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefinerModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -208,6 +226,48 @@ const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldIn
return template;
};
const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5EncoderModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: T5EncoderModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CLIPEmbedModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxVAEModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -386,11 +446,15 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
MainModelField: buildMainModelFieldInputTemplate,
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate,
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
} as const;
export const buildFieldInputTemplate = (

View File

@@ -29,6 +29,7 @@ const MODEL_FIELD_TYPES = [
'ModelIdentifier',
'MainModelField',
'SDXLMainModelField',
'FluxMainModelField',
'SDXLRefinerModelField',
'VAEModelField',
'LoRAModelField',

View File

@@ -8,7 +8,7 @@ import { modelSelected } from 'features/parameters/store/actions';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useMainModels } from 'services/api/hooks/modelsByType';
import { useSDMainModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
@@ -17,7 +17,7 @@ const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selectedModel = useAppSelector(selectModel);
const [modelConfigs, { isLoading }] = useMainModels();
const [modelConfigs, { isLoading }] = useSDMainModels();
const tooltipLabel = useMemo(() => {
if (!modelConfigs.length || !selectedModel) {
return;

View File

@@ -9,6 +9,7 @@ export const MODEL_TYPE_MAP = {
'sd-2': 'Stable Diffusion 2.x',
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
flux: 'Flux',
};
/**
@@ -20,6 +21,7 @@ export const MODEL_TYPE_SHORT_MAP = {
'sd-2': 'SD2.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
flux: 'FLUX',
};
/**
@@ -46,6 +48,10 @@ export const CLIP_SKIP_MAP = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
flux: {
maxClip: 0,
markers: [],
},
};
/**

View File

@@ -48,7 +48,12 @@ export const UpscaleSettingsAccordion = memo(() => {
});
return (
<StandaloneAccordion label="Upscale" badges={badges} isOpen={isOpenAccordion} onToggle={onToggleAccordion}>
<StandaloneAccordion
label={t('upscaling.upscale')}
badges={badges}
isOpen={isOpenAccordion}
onToggle={onToggleAccordion}
>
<Flex pt={4} px={4} w="full" h="full" flexDir="column" data-testid="upscale-settings-accordion">
<Flex flexDir="column" gap={4}>
<Flex gap={4}>

View File

@@ -1,4 +1,4 @@
import { Badge, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { Badge, Flex, IconButton, Spacer, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { negativePromptChanged, positivePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import { usePresetModifiedPrompts } from 'features/stylePresets/hooks/usePresetModifiedPrompts';
@@ -69,45 +69,40 @@ export const ActiveStylePreset = () => {
);
}
return (
<Flex justifyContent="space-between" w="full" alignItems="center">
<Flex gap={2} alignItems="center">
<StylePresetImage imageWidth={25} presetImageUrl={activeStylePreset.image} />
<Flex flexDir="column">
<Badge colorScheme="invokeBlue" variant="subtle">
{activeStylePreset.name}
</Badge>
</Flex>
</Flex>
<Flex gap={1}>
<Tooltip label={t('stylePresets.toggleViewMode')}>
<IconButton
onClick={handleToggleViewMode}
variant="outline"
size="sm"
aria-label={t('stylePresets.toggleViewMode')}
colorScheme={viewMode ? 'invokeBlue' : 'base'}
icon={<PiEyeBold />}
/>
</Tooltip>
<Tooltip label={t('stylePresets.flatten')}>
<IconButton
onClick={handleFlattenPrompts}
variant="outline"
size="sm"
aria-label={t('stylePresets.flatten')}
icon={<PiStackSimpleBold />}
/>
</Tooltip>
<Tooltip label={t('stylePresets.clearTemplateSelection')}>
<IconButton
onClick={handleClearActiveStylePreset}
variant="outline"
size="sm"
aria-label={t('stylePresets.clearTemplateSelection')}
icon={<PiXBold />}
/>
</Tooltip>
</Flex>
<Flex w="full" alignItems="center" gap={2} minW={0}>
<StylePresetImage imageWidth={25} presetImageUrl={activeStylePreset.image} />
<Badge colorScheme="invokeBlue" variant="subtle" justifySelf="flex-start">
{activeStylePreset.name}
</Badge>
<Spacer />
<Tooltip label={t('stylePresets.toggleViewMode')}>
<IconButton
onClick={handleToggleViewMode}
variant="outline"
size="sm"
aria-label={t('stylePresets.toggleViewMode')}
colorScheme={viewMode ? 'invokeBlue' : 'base'}
icon={<PiEyeBold />}
/>
</Tooltip>
<Tooltip label={t('stylePresets.flatten')}>
<IconButton
onClick={handleFlattenPrompts}
variant="outline"
size="sm"
aria-label={t('stylePresets.flatten')}
icon={<PiStackSimpleBold />}
/>
</Tooltip>
<Tooltip label={t('stylePresets.clearTemplateSelection')}>
<IconButton
onClick={handleClearActiveStylePreset}
variant="outline"
size="sm"
aria-label={t('stylePresets.clearTemplateSelection')}
icon={<PiXBold />}
/>
</Tooltip>
</Flex>
);
};

View File

@@ -16,9 +16,9 @@ export const StylePresetExportButton = () => {
const { t } = useTranslation();
const { presetCount } = useListStylePresetsQuery(undefined, {
selectFromResult: ({ data }) => {
const userPresets = data?.filter((preset) => preset.type === 'user') ?? EMPTY_ARRAY;
const presetsToExport = data?.filter((preset) => preset.type !== 'default') ?? EMPTY_ARRAY;
return {
presetCount: userPresets.length,
presetCount: presetsToExport.length,
};
},
});

View File

@@ -30,8 +30,8 @@ export const StylePresetForm = ({
updatingStylePresetId: string | null;
formData: StylePresetFormData | null;
}) => {
const [createStylePreset] = useCreateStylePresetMutation();
const [updateStylePreset] = useUpdateStylePresetMutation();
const [createStylePreset, { isLoading: isCreating }] = useCreateStylePresetMutation();
const [updateStylePreset, { isLoading: isUpdating }] = useUpdateStylePresetMutation();
const { t } = useTranslation();
const allowPrivateStylePresets = useAppSelector((s) => s.config.allowPrivateStylePresets);
@@ -93,8 +93,8 @@ export const StylePresetForm = ({
</FormControl>
</Flex>
<StylePresetPromptField label="Positive Prompt" control={control} name="positivePrompt" />
<StylePresetPromptField label="Negative Prompt" control={control} name="negativePrompt" />
<StylePresetPromptField label={t('stylePresets.positivePrompt')} control={control} name="positivePrompt" />
<StylePresetPromptField label={t('stylePresets.negativePrompt')} control={control} name="negativePrompt" />
<Box>
<Text variant="subtext">{t('stylePresets.promptTemplatesDesc1')}</Text>
<Text variant="subtext">
@@ -109,7 +109,11 @@ export const StylePresetForm = ({
<Flex justifyContent="space-between" alignItems="flex-end" gap={10}>
{allowPrivateStylePresets ? <StylePresetTypeField control={control} name="type" /> : <Spacer />}
<Button onClick={handleSubmit(handleClickSave)} isDisabled={!formState.isValid}>
<Button
onClick={handleSubmit(handleClickSave)}
isDisabled={!formState.isValid}
isLoading={isCreating || isUpdating}
>
{t('common.save')}
</Button>
</Flex>

View File

@@ -48,9 +48,13 @@ export const StylePresetModal = () => {
} else {
let file = null;
if (data.imageUrl) {
const blob = await convertImageUrlToBlob(data.imageUrl);
if (blob) {
file = new File([blob], 'style_preset.png', { type: 'image/png' });
try {
const blob = await convertImageUrlToBlob(data.imageUrl);
if (blob) {
file = new File([blob], 'style_preset.png', { type: 'image/png' });
}
} catch (error) {
// do nothing
}
}
setFormData({

View File

@@ -21,6 +21,7 @@ const StylePresetImage = ({ presetImageUrl, imageWidth }: { presetImageUrl: stri
/>
)
}
p={2}
>
<Image
src={presetImageUrl || ''}

View File

@@ -77,7 +77,7 @@ export const StylePresetListItem = ({ preset }: { preset: StylePresetRecordWithI
const handleDeletePreset = useCallback(async () => {
try {
await deleteStylePreset(preset.id);
await deleteStylePreset(preset.id).unwrap();
toast({
status: 'success',
title: t('stylePresets.templateDeleted'),
@@ -174,7 +174,8 @@ export const StylePresetListItem = ({ preset }: { preset: StylePresetRecordWithI
onClose={onClose}
title={t('stylePresets.deleteTemplate')}
acceptCallback={handleDeletePreset}
acceptButtonText="Delete"
acceptButtonText={t('common.delete')}
cancelButtonText={t('common.cancel')}
>
<p>{t('stylePresets.deleteTemplate2')}</p>
</ConfirmationAlertDialog>

View File

@@ -29,14 +29,14 @@ export const StylePresetMenuTrigger = () => {
py={2}
px={3}
borderRadius="base"
gap={1}
gap={2}
role="button"
_hover={_hover}
transitionProperty="background-color"
transitionDuration="normal"
w="full"
>
<ActiveStylePreset />
<IconButton aria-label={t('stylePresets.viewList')} variant="ghost" icon={<PiCaretDownBold />} size="sm" />
</Flex>
);

View File

@@ -1,6 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import type { StylePresetState } from './types';
@@ -24,6 +25,26 @@ export const stylePresetSlice = createSlice({
state.viewMode = action.payload;
},
},
extraReducers(builder) {
builder.addMatcher(stylePresetsApi.endpoints.deleteStylePreset.matchFulfilled, (state, action) => {
if (state.activeStylePresetId === null) {
return;
}
const deletedId = action.meta.arg.originalArgs;
if (state.activeStylePresetId === deletedId) {
state.activeStylePresetId = null;
}
});
builder.addMatcher(stylePresetsApi.endpoints.listStylePresets.matchFulfilled, (state, action) => {
if (state.activeStylePresetId === null) {
return;
}
const ids = action.payload.map((preset) => preset.id);
if (!ids.includes(state.activeStylePresetId)) {
state.activeStylePresetId = null;
}
});
},
});
export const { activeStylePresetIdChanged, searchTermChanged, viewModeChanged } = stylePresetSlice.actions;

View File

@@ -0,0 +1,53 @@
import { getViewModeChunks } from 'features/stylePresets/util/getViewModeChunks';
import { describe, expect, it } from 'vitest';
describe('getViewModeChunks', () => {
it('should return empty strings when presetPrompt is not provided', () => {
const currentPrompt = 'current prompt';
const presetPrompt = undefined;
const result = getViewModeChunks(currentPrompt, presetPrompt);
expect(result).toEqual(['', currentPrompt, '']);
});
it('should return empty strings when presetPrompt is empty', () => {
const currentPrompt = 'current prompt';
const presetPrompt = '';
const result = getViewModeChunks(currentPrompt, presetPrompt);
expect(result).toEqual(['', currentPrompt, '']);
});
it('should append presetPrompt to currentPrompt when presetPrompt does not contain PRESET_PLACEHOLDER', () => {
const currentPrompt = 'current prompt';
const presetPrompt = 'preset prompt';
const result = getViewModeChunks(currentPrompt, presetPrompt);
expect(result).toEqual(['', `${currentPrompt} `, presetPrompt]);
});
it('should split presetPrompt into 3 parts when presetPrompt contains PRESET_PLACEHOLDER', () => {
const currentPrompt = 'current prompt';
const presetPrompt = 'before {prompt} after';
const result = getViewModeChunks(currentPrompt, presetPrompt);
expect(result).toEqual(['before ', currentPrompt, ' after']);
});
it('should split presetPrompt into 3 parts when presetPrompt contains multiple PRESET_PLACEHOLDER', () => {
const currentPrompt = 'current prompt';
const presetPrompt = 'before {prompt} middle {prompt} after';
const result = getViewModeChunks(currentPrompt, presetPrompt);
expect(result).toEqual(['before ', currentPrompt, ' middle {prompt} after']);
});
it('should handle the PRESET_PLACEHOLDER being at the start of the presetPrompt', () => {
const currentPrompt = 'current prompt';
const presetPrompt = '{prompt} after';
const result = getViewModeChunks(currentPrompt, presetPrompt);
expect(result).toEqual(['', currentPrompt, ' after']);
});
it('should handle the PRESET_PLACEHOLDER being at the end of the presetPrompt', () => {
const currentPrompt = 'current prompt';
const presetPrompt = 'before {prompt}';
const result = getViewModeChunks(currentPrompt, presetPrompt);
expect(result).toEqual(['before ', currentPrompt, '']);
});
});

View File

@@ -0,0 +1,17 @@
import { PRESET_PLACEHOLDER } from 'features/stylePresets/hooks/usePresetModifiedPrompts';
export const getViewModeChunks = (currentPrompt: string, presetPrompt?: string): [string, string, string] => {
if (!presetPrompt || !presetPrompt.length) {
return ['', currentPrompt, ''];
}
// When preset prompt does not contain the placeholder, we append the preset to the current prompt
if (!presetPrompt.includes(PRESET_PLACEHOLDER)) {
return ['', `${currentPrompt} `, presetPrompt];
}
// Otherwise, we split the preset prompt into 3 parts: before, current, and after the placeholder
const [before, ...after] = presetPrompt.split(PRESET_PLACEHOLDER);
return [before || '', currentPrompt, after.join(PRESET_PLACEHOLDER) || ''];
};

View File

@@ -1,15 +0,0 @@
import { PRESET_PLACEHOLDER } from 'features/stylePresets/hooks/usePresetModifiedPrompts';
export const getViewModeChunks = (currentPrompt: string, presetPrompt?: string): [string, string, string] => {
if (!presetPrompt || !presetPrompt.length) {
return ['', currentPrompt, ''];
}
const [before, after] = presetPrompt.split(PRESET_PLACEHOLDER, 2);
if (!before || !after) {
return ['', `${currentPrompt} `, before || after || ''];
}
return [before ?? '', currentPrompt, after ?? ''];
};

View File

@@ -94,7 +94,7 @@ export const stylePresetsApi = api.injectEndpoints({
}),
exportStylePresets: build.query<string, void>({
query: () => ({
url: buildStylePresetsUrl('/export'),
url: buildStylePresetsUrl('export'),
responseHandler: (response) => response.text(),
}),
providesTags: ['FetchOnReconnect', { type: 'StylePreset', id: LIST_TAG }],

Some files were not shown because too many files have changed in this diff Show More