Compare commits

..

368 Commits

Author SHA1 Message Date
Ryan Dick
6bcf48aa37 WIP - Started working towards MultiDiffusion batching. 2024-06-18 15:44:39 -04:00
Ryan Dick
b1bb1511fe Delete rough notes. 2024-06-18 15:36:36 -04:00
Ryan Dick
99046a8145 Fix advanced scheduler behaviour in MultiDiffusionPipeline. 2024-06-18 15:36:36 -04:00
Ryan Dick
72be7e71e3 Fix handling of stateful schedulers in MultiDiffusionPipeline. 2024-06-18 15:36:36 -04:00
Ryan Dick
35adaf1c17 Connect TiledMultiDiffusionDenoiseLatents to the MultiDiffusionPipeline backend. 2024-06-18 15:36:34 -04:00
Ryan Dick
865c2335de Remove regional conditioning logic from MultiDiffusionPipeline - it is not yet supported. 2024-06-18 15:35:52 -04:00
Ryan Dick
49ca42f84a Initial (untested) implementation of MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick
493fcd8660 Remove inpainting support from MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick
20322d781e Remove IP-Adapter and T2I-Adapter support from MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick
889d13e02a Document plan for the rest of the MultiDiffusion implementation. 2024-06-18 15:35:52 -04:00
Ryan Dick
6ccd2a867b Add detailed docstring to latents_from_embeddings(). 2024-06-18 15:35:52 -04:00
Ryan Dick
5861fa1719 Copy StableDiffusionGeneratorPipeline as a starting point for a new MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick
dfd4beb62b Simplify handling of inpainting models. Improve the in-code documentation around inpainting. 2024-06-18 15:35:52 -04:00
Ryan Dick
83df0c0df5 Minor tidying of latents_from_embeddings(...). 2024-06-18 15:35:52 -04:00
Ryan Dick
c58c4069a7 Consolidate latents_from_embeddings(...) and generate_latents_from_embeddings(...) into a single function. 2024-06-18 15:35:52 -04:00
Ryan Dick
3937fffa94 Fix invocation name of tiled_multi_diffusion_denoise_latents. 2024-06-18 15:35:52 -04:00
Ryan Dick
bbf5f67691 Improve clarity of comments regarding when 'noise' and 'latents' are expected to be set. 2024-06-18 15:35:52 -04:00
Ryan Dick
2f5c147b84 Fix static check errors on imports in diffusers_pipeline.py. 2024-06-18 15:35:52 -04:00
Ryan Dick
bd2839b748 Remove a condition for handling inpainting models that never resolves to True. The same logic is already applied earlier by AddsMaskLatents. 2024-06-18 15:35:52 -04:00
Ryan Dick
4f70dd7ce1 Add clarifying comment to explain why noise might be None in latents_from_embedding(). 2024-06-18 15:35:52 -04:00
Ryan Dick
066672fbfd Remove unused are_like_tensors() function. 2024-06-18 15:35:52 -04:00
Ryan Dick
abefaee4d1 Remove unused StableDiffusionGeneratorPipeline.use_ip_adapter member. 2024-06-18 15:35:52 -04:00
Ryan Dick
3254ba5904 Remove unused StableDiffusionGeneratorPipeline.control_model. 2024-06-18 15:35:52 -04:00
Ryan Dick
73a8c55852 Stricter typing for the is_gradient_mask: bool. 2024-06-18 15:35:52 -04:00
Ryan Dick
f82af7c22d Fix typing of control_data to reflect that it can be None. 2024-06-18 15:35:52 -04:00
Ryan Dick
3aef717ef4 Fix typing of timesteps and init_timestep. 2024-06-18 15:35:52 -04:00
Ryan Dick
c2cf1137e9 Fix typing to reflect that the callback arg to latents_from_embeddings is never None. 2024-06-18 15:35:52 -04:00
Ryan Dick
803a24bc0a Move seed above optional params. 2024-06-18 15:35:52 -04:00
Ryan Dick
7d24ad8ccd Simplify handling of AddsMaskGuidance, and fix some related type errors. 2024-06-18 15:35:52 -04:00
Ryan Dick
cb389063b2 Remove unused num_inference_steps. 2024-06-18 15:35:52 -04:00
Ryan Dick
81b8a69e1a WIP TiledMultiDiffusionDenoiseLatents. Updated parameter list and first half of the logic. 2024-06-18 15:35:50 -04:00
Ryan Dick
7ee5db87ad Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors. 2024-06-18 15:34:30 -04:00
Ryan Dick
66cf2c59bd Make DenoiseLatentsInvocation.prep_control_data(...) a staticmethod so that it can be called externally. 2024-06-18 15:34:30 -04:00
Ryan Dick
3bad1367e9 Copy TiledStableDiffusionRefineInvocation as a starting point for TiledMultiDiffusionDenoiseLatents.py 2024-06-18 15:34:22 -04:00
Ryan Dick
867a7642a6 Change tiling strategy to make TiledStableDiffusionRefineInvocation work with more tile shapes and overlaps. 2024-06-18 15:31:58 -04:00
Ryan Dick
d9d1c8f9cb Expose a few more params from TiledStableDiffusionRefineInvocation. 2024-06-18 15:31:58 -04:00
Ryan Dick
e03eb7fb45 Add support for LoRA models in TiledStableDiffusionRefineInvocation. 2024-06-18 15:31:58 -04:00
Ryan Dick
85db33bc7e Add naive ControlNet support to TiledStableDiffusionRefineInvocation 2024-06-18 15:31:58 -04:00
Ryan Dick
93e3a2b504 Fix ControlNetModel type hint import source. 2024-06-18 15:31:58 -04:00
Ryan Dick
6a7a26f1bf Rough prototype of TiledStableDiffusionRefineInvocation is working. 2024-06-18 15:31:58 -04:00
Ryan Dick
08ca03ef9f WIP - TiledStableDiffusionRefine 2024-06-18 15:31:54 -04:00
Ryan Dick
ccf90b6bd6 Minor improvements to LatentsToImageInvocation type hints. 2024-06-18 15:31:21 -04:00
Ryan Dick
753239b48d Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation. 2024-06-18 15:31:21 -04:00
Ryan Dick
65fa4664c9 Fix return type of prepare_noise_and_latents(...). 2024-06-18 15:31:21 -04:00
Ryan Dick
297570ded3 Make init_scheduler() a staticmethod on DenoiseLatentsInvocation so that it can be called externally. 2024-06-18 15:31:21 -04:00
Ryan Dick
680fdcf293 Only allow a single positive/negative prompt conditioning input for tiled refine. 2024-06-18 15:31:21 -04:00
Ryan Dick
5ff91f2c44 WIP on TiledStableDiffusionRefine 2024-06-18 15:31:14 -04:00
Ryan Dick
69aa7057e7 Convert several methods in DenoiseLatentsInvocation to staticmethods so that they can be called externally. 2024-06-18 15:25:08 -04:00
Ryan Dick
d3932f40de Simplify the logic in prepare_noise_and_latents(...). 2024-06-18 15:25:08 -04:00
Ryan Dick
ee74cd7fab Split out the prepare_noise_and_latents(...) logic in DenoiseLatentsInvocation so that it can be called from other invocations. 2024-06-18 15:25:08 -04:00
Ryan Dick
bda25b40c9 (minor) Add a TODO note to get_scheduler(...). 2024-06-18 15:25:08 -04:00
Ryan Dick
7e9a89f8c6 Tidy SilenceWarnings context manager (#6493)
## Summary

No functional changes, just cleaning some things up as I touch the code.
This PR cleans up the `SilenceWarnings` context manager:
- Fix type errors
- Enable SilenceWarnings to be used as both a context manager and a
decorator
- Remove duplicate implementation
- Check the initial verbosity on `__enter__()` rather than `__init__()`
- Save an indentation level in DenoiseLatents

## QA Instructions

I generated an image to confirm that warnings are still muted.

## Merge Plan

- [x] ⚠️ Merge https://github.com/invoke-ai/InvokeAI/pull/6492 first,
then change the target branch to `main`.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-06-18 15:23:32 -04:00
Ryan Dick
79ceac2f82 (minor) Use SilenceWarnings as a decorator rather than a context manager to save an indentation level. 2024-06-18 15:06:22 -04:00
Ryan Dick
8e47e005a7 Tidy SilenceWarnings context manager:
- Fix type errors
- Enable SilenceWarnings to be used as both a context manager and a decorator
- Remove duplicate implementation
- Check the initial verbosity on __enter__() rather than __init__()
2024-06-18 15:06:22 -04:00
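A minimal sketch of the pattern described in the commit above - a warnings silencer usable as both a context manager and a decorator, with state captured in `__enter__()` rather than `__init__()`. The real `SilenceWarnings` targets diffusers/transformers verbosity; the root logger here is a stand-in and all names are illustrative.

```
import logging
import warnings
from contextlib import ContextDecorator

class SilenceWarnings(ContextDecorator):
    """Silence warnings while active; usable as a context manager or a decorator."""

    def __enter__(self):
        # Capture the initial state here (not in __init__) so nested/late usage works.
        self._saved_filters = warnings.filters[:]
        self._saved_level = logging.root.level
        warnings.simplefilter("ignore")
        logging.root.setLevel(logging.ERROR)
        return self

    def __exit__(self, *exc):
        warnings.filters[:] = self._saved_filters
        logging.root.setLevel(self._saved_level)
        return False

# Context-manager usage:
with SilenceWarnings():
    warnings.warn("hidden", UserWarning)

# Decorator usage (same behaviour, one less indentation level):
@SilenceWarnings()
def denoise():
    warnings.warn("also hidden", UserWarning)
```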
Ryan Dick
d13aafb514 Tidy denoise_latents.py imports to all use absolute import paths. 2024-06-18 15:06:22 -04:00
Brandon Rising
63a7e19dbf Run ruff 2024-06-18 10:38:29 -04:00
Brandon Rising
fbc5a8ec65 Ignore validation on improperly formatted hashes (pytest) 2024-06-18 10:38:29 -04:00
Brandon Rising
8ce6e4540e Run ruff 2024-06-18 10:38:29 -04:00
Brandon Rising
f14f377ede Update validator list 2024-06-18 10:38:29 -04:00
Brandon Rising
1925f83f5e Update validator list 2024-06-18 10:38:29 -04:00
Brandon Rising
3a5ad6d112 Update validator list 2024-06-18 10:38:29 -04:00
Brandon Rising
41a6bb45f3 Initial functionality 2024-06-18 10:38:29 -04:00
chainchompa
70e40fa6c1 added route to install huggingface models from model marketplace (#6515)
## Summary
added route to install huggingface models from model marketplace

## QA Instructions
test by going to
http://localhost:5173/api/v2/models/install/huggingface?source=${hfRepo}

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-06-16 21:13:58 -04:00
psychedelicious
e26125b734 tests: fix test_model_install.py 2024-06-17 10:57:11 +10:00
psychedelicious
cd70937b7f feat(api): improved model install confirmation page styling & messaging 2024-06-17 10:51:08 +10:00
psychedelicious
f002bca2fa feat(ui): handle new model_install_download_started event
When a model install is initiated from outside the client, we now trigger the model manager tab's model install list to update.

- Handle new `model_install_download_started` event
- Handle `model_install_download_complete` event (this event is not new but was never handled)
- Update optimistic updates/cache invalidation logic to efficiently update the model install list
2024-06-17 10:07:10 +10:00
psychedelicious
56771de856 feat(ui): add redux actions for model_install_download_started event 2024-06-17 09:52:46 +10:00
psychedelicious
c11478a94a chore(ui): typegen 2024-06-17 09:51:18 +10:00
psychedelicious
fb694b3e17 feat(app): add model_install_download_started event
Previously, we used `model_install_download_progress` for both download starting and progressing. When handling this event, we don't know which actual thing it represents.

Add `model_install_download_started` event to explicitly represent a model download started event.
2024-06-17 09:50:25 +10:00
psychedelicious
1bc98abc76 docs(ui): explain model install events 2024-06-17 09:33:46 +10:00
chainchompa
7f03b04b2f Merge branch 'main' into chainchompa/model-install-deeplink 2024-06-14 17:16:25 -04:00
chainchompa
4029972530 formatting 2024-06-14 17:15:55 -04:00
chainchompa
328f160e88 refetch model installs when a new model install starts 2024-06-14 17:09:07 -04:00
chainchompa
aae318425d added route for installing huggingface model from model marketplace 2024-06-14 17:08:39 -04:00
Ryan Dick
785bb1d9e4 Fix all comparisons against the DEFAULT_PRECISION constant. DEFAULT_PRECISION is a torch.dtype. Previously, it was compared to a str in a number of places where it would always resolve to False. This is a bugfix that results in a change to the default behavior. In practice, this will not change the behavior for many users, because it only causes a change in behavior if a user has configured float32 as their default precision. 2024-06-14 11:26:10 -07:00
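As a quick illustration of the bug class described in this commit (a sketch; the constant name is taken from the message, everything else is illustrative):

```
import torch

DEFAULT_PRECISION = torch.float32  # a torch.dtype, not a str

# Per the commit above, comparing the dtype against a str always resolves to False,
# so a branch like this never ran:
if DEFAULT_PRECISION == "float32":
    print("never reached")

# Comparing dtype against dtype behaves as intended:
if DEFAULT_PRECISION == torch.float32:
    print("float32 is the default precision")
```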
Lincoln Stein
a3cb5da130 Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere (#6490)
* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* do not save original weights if there is a CPU copy of state dict

* Update invokeai/backend/model_manager/load/load_base.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* documentation fixes requested during penultimate review

* add non-blocking=True parameters to several torch.nn.Module.to() calls, for slight performance increases

* fix ruff errors

* prevent crash on non-cuda-enabled systems

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2024-06-13 17:10:03 +00:00
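A small sketch of the `non_blocking=True` change listed in the commit above (illustrative; not the project's actual call sites):

```
import torch

model = torch.nn.Linear(8, 8)

if torch.cuda.is_available():
    # non_blocking=True asks PyTorch to make the host-to-device copy asynchronous
    # where possible (actual overlap requires pinned host memory), giving a slight
    # speedup for RAM->VRAM transfers.
    model = model.to("cuda", non_blocking=True)
```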
blessedcoolant
568a4844f7 fix: other recursive imports 2024-06-10 04:12:20 -07:00
blessedcoolant
b1e56e2485 fix: SchedulerOutput not being imported correctly 2024-06-10 04:12:20 -07:00
Kent Keirsey
9432336e2b Add simplified model manager install API to InvocationContext (#6132)
## Summary

This PR adds three model manager-related methods to the InvocationContext
uniform API. They are accessible via `context.models.*`:

1. **`load_local_model(model_path: Path, loader:
Optional[Callable[[Path], AnyModel]] = None) ->
LoadedModelWithoutConfig`**

*Load the model located at the indicated path.*

This will load a local model (.safetensors, .ckpt or diffusers
directory) into the model manager RAM cache and return its
`LoadedModelWithoutConfig`. If the optional loader argument is provided,
the loader will be invoked to load the model into memory. Otherwise, the
method will call `safetensors.torch.load_file()`, `torch.load()` (with a
pickle scan), or `from_pretrained()` as appropriate to the path type.

Be aware that the `LoadedModelWithoutConfig` object differs from
`LoadedModel` by having no `config` attribute.

Here is an example of usage:

```
def invoke(self, context: InvocationContext) -> ImageOutput:
       model_path = Path('/opt/models/RealESRGAN_x4plus.pth')
       loadnet = context.models.load_local_model(model_path)
       with loadnet as loadnet_model:
             upscaler = RealESRGAN(loadnet=loadnet_model,...)
```

---

2. **`load_remote_model(source: str | AnyHttpUrl, loader:
Optional[Callable[[Path], AnyModel]] = None) ->
LoadedModelWithoutConfig`**

*Load the model located at the indicated URL or repo_id.*

This is similar to `load_local_model()` but it accepts either a
Hugging Face repo_id (as a string), or a URL. The model's file(s) will be
downloaded to `models/.download_cache` and then loaded, returning a
`LoadedModelWithoutConfig`. Here is an example of usage:

```
def invoke(self, context: InvocationContext) -> ImageOutput:
       model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
       loadnet = context.models.load_remote_model(model_url)
       with loadnet as loadnet_model:
             upscaler = RealESRGAN(loadnet=loadnet_model,...)
```
---

3. **`download_and_cache_model(source: str | AnyHttpUrl, access_token:
Optional[str] = None, timeout: Optional[int] = 0) -> Path`**

Download the model file located at source to the models cache and return
its Path. This will check `models/.download_cache` for the desired model
file and download it from the indicated source if not already present.
The local Path to the downloaded file is then returned.
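A usage sketch for this third method, in the same style as the examples above (illustrative; the exact return handling is an assumption):

```
def invoke(self, context: InvocationContext) -> ImageOutput:
       model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
       # model_path is the local Path of the cached file under models/.download_cache
       model_path = context.models.download_and_cache_model(model_url)
```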

---

## Other Changes

This PR performs a migration, in which it renames `models/.cache` to
`models/.convert_cache`, and migrates previously-downloaded ESRGAN,
openpose, DepthAnything and Lama inpaint models from the `models/core`
directory into `models/.download_cache`.

There are a number of legacy model files in `models/core`, such as
GFPGAN, which are no longer used. This PR deletes them and tidies up the
`models/core` directory.

## Related Issues / Discussions

I have systematically replaced all the calls to
`download_with_progress_bar()`. This function is no longer used
elsewhere and has been removed.

## QA Instructions

I have added unit tests for the three new calls. You may test that the
`load_and_cache_model()` call is working by running the upscaler within
the web app. On first try, you will see the model file being downloaded
into the models `.cache` directory. On subsequent tries, the model will
either load from RAM (if it hasn't been displaced) or will be loaded
from the filesystem.

## Merge Plan

Squash merge when approved.

## Checklist

- [X] _The PR has a short but descriptive title, suitable for a
changelog_
- [X] _Tests added / updated (if applicable)_
- [X] _Documentation added / updated (if applicable)_
2024-06-08 16:24:31 -07:00
Lincoln Stein
7d19af2caa Merge branch 'main' into lstein/feat/simple-mm2-api 2024-06-08 18:55:06 -04:00
Ryan Dick
0dbec3ad8b Split up latent.py (code reorganization, no functional changes) (#6491)
## Summary

I've started working towards a better tiled upscaling implementation. It
is going to require some refactoring of `DenoiseLatentsInvocation`. As a
first step, this PR splits up all of the invocations in latent.py into
their own files. That file had become a bit of a dumping ground - it
should be a bit more manageable to work with now.

This PR just re-organizes the code. There should be no functional
changes.

## QA Instructions

I've done some light smoke testing. I'll do some more before merging.
The main risk is that I missed a broken import, or some other copy-paste
error.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_: N/A
- [x] _Documentation added / updated (if applicable)_: N/A
2024-06-07 12:01:56 -04:00
Ryan Dick
52c0c4a32f Rename latent.py -> denoise_latents.py. 2024-06-07 09:28:42 -04:00
Ryan Dick
8f1afc032a Move SchedulerInvocation to a new file. No functional changes. 2024-06-07 09:28:42 -04:00
Ryan Dick
854bca668a Move CreateDenoiseMaskInvocation to its own file. No functional changes. 2024-06-07 09:28:42 -04:00
Ryan Dick
fea9013cad Move CreateGradientMaskInvocation to its own file. No functional changes. 2024-06-07 09:28:42 -04:00
Ryan Dick
045caddee1 Move LatentsToImageInvocation to its own file. No functional changes. 2024-06-07 09:28:42 -04:00
Ryan Dick
58697141bf Move ImageToLatentsInvocation to its own file. No functional changes. 2024-06-07 09:28:42 -04:00
Ryan Dick
5e419dbb56 Move ScaleLatentsInvocation and ResizeLatentsInvocation to their own file. No functional changes. 2024-06-07 09:28:42 -04:00
Ryan Dick
595096bdcf Move BlendLatentsInvocation to its own file. No functional changes. 2024-06-07 09:28:42 -04:00
Ryan Dick
ed03d281e6 Move CropLatentsCoreInvocation to its own file. No functional changes. 2024-06-07 09:28:42 -04:00
Ryan Dick
0b37496c57 Move IdealSizeInvocation to its own file. No functional changes. 2024-06-07 09:28:42 -04:00
psychedelicious
fde58ce0a3 Merge remote-tracking branch 'origin/main' into lstein/feat/simple-mm2-api 2024-06-07 14:23:41 +10:00
Lincoln Stein
dc134935c8 replace load_and_cache_model() with load_remote_model() and load_local_model() 2024-06-07 14:12:16 +10:00
Lincoln Stein
9f9379682e ruff fixes 2024-06-07 13:54:41 +10:00
Lincoln Stein
f81b8bc9f6 add support for generic loading of diffusers directories 2024-06-07 13:54:30 +10:00
psychedelicious
6d067e56f2 fix(ui): on page load, if CA processed image no longer exists, re-process it 2024-06-07 10:32:28 +10:00
Lincoln Stein
2871676f79 LoRA patching optimization (#6439)
* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* do not save original weights if there is a CPU copy of state dict

* Update invokeai/backend/model_manager/load/load_base.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* documentation fixes added during penultimate review

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2024-06-06 13:53:35 +00:00
psychedelicious
1c5c3cdbd6 tidy(ui): organize control layers konva logic
- More comments, docstrings
- Move things into saner, less-coupled locations
2024-06-06 07:45:13 +10:00
psychedelicious
3db69af220 refactor(ui): generalize stage event handlers
Create intermediary nanostores for values required by the event handlers. This allows the event handlers to be purely imperative, with no reactivity: instead of recreating/setting the handlers when a dependent piece of state changes, we use nanostores' imperative API to access dependent state.

For example, some handlers depend on brush size. If we used the standard declarative `useSelector` API, we'd need to recreate the event handler callback each time the brush size changed. This can be costly.

An intermediate `$brushSize` nanostore is set in a `useLayoutEffect()`, which responds to changes to the redux store. Then, in the event handler, we use the imperative API to access the brush size: `$brushSize.get()`.

This change allows the event handler logic to be shared with the pending canvas v2, and also more easily tested. It's a noticeable perf improvement, too, especially when changing brush size.
2024-06-06 07:45:13 +10:00
psychedelicious
1823e446ac fix(ui): conditionally render CL preview
This fixes an issue where it sometimes gets out of sync, and fixes some konva errors.
2024-06-06 07:45:13 +10:00
psychedelicious
311e44ad19 tidy(ui): clean up control layers renderers, docstrings 2024-06-06 07:45:13 +10:00
jstnlowe
848ca79da8 Changed translated labels to static suffixes, cleanup. 2024-06-05 14:45:43 +10:00
jstnlowe
9cba0dfac9 Providing fileName string directly to DataViewer as suggested 2024-06-05 14:45:43 +10:00
jstnlowe
37b1f21bcf ... and the workflow 2024-06-05 14:45:43 +10:00
jstnlowe
b2e005f6b5 Just realized we might want the same change made for the Graph JSON 2024-06-05 14:45:43 +10:00
jstnlowe
52aac954c0 Prefixed JSON filenames with the image UUID #6469 2024-06-05 14:45:43 +10:00
psychedelicious
ff01ceae99 Update invokeai_version.py 2024-06-05 05:53:19 +10:00
hugoalh
669d92d8db translationBot(ui): update translation (Chinese (Traditional))
Currently translated at 14.1% (179 of 1261 strings)

Co-authored-by: hugoalh <hugoalh@users.noreply.hosted.weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hant/
Translation: InvokeAI/Web UI
2024-06-05 00:08:03 +10:00
Ettore Atalan
2903060154 translationBot(ui): update translation (German)
Currently translated at 67.0% (834 of 1243 strings)

Co-authored-by: Ettore Atalan <atalanttore@googlemail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-06-05 00:08:03 +10:00
gallegonovato
4af8699a00 translationBot(ui): update translation (Spanish)
Currently translated at 34.3% (427 of 1243 strings)

Co-authored-by: gallegonovato <fran-carro@hotmail.es>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/es/
Translation: InvokeAI/Web UI
2024-06-05 00:08:03 +10:00
Bruno Castillejo
71fedd1a07 translationBot(ui): update translation (Spanish)
Currently translated at 34.3% (427 of 1243 strings)

Co-authored-by: Bruno Castillejo <soybrunocastillejo@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/es/
Translation: InvokeAI/Web UI
2024-06-05 00:08:03 +10:00
Riccardo Giovanetti
6bb1189c88 translationBot(ui): update translation (Italian)
Currently translated at 98.5% (1243 of 1261 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1243 of 1261 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1225 of 1243 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1225 of 1243 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-06-05 00:08:03 +10:00
Васянатор
c7546bc82e translationBot(ui): update translation (Russian)
Currently translated at 100.0% (1261 of 1261 strings)

translationBot(ui): update translation (Russian)

Currently translated at 100.0% (1243 of 1243 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-06-05 00:08:03 +10:00
psychedelicious
14372e3818 fix(nodes): blend latents with weight=0 with DPMSolverSDEScheduler
- Pass the seed from `latents_a` to the output latents. Fixed an issue where using `BlendLatentsInvocation` could result in different outputs during denoising even when the alpha or slerp weight was 0.

## Explanation

`LatentsField` has an optional `seed` field. During denoising, if this `seed` field is not present, we **fall back to 0 for the seed**. The seed is used during denoising in a few ways:

1. Initializing the scheduler.

The seed is used in two places in `invokeai/app/invocations/latent.py`.

The `get_scheduler()` utility function has special handling for `DPMSolverSDEScheduler`, which appears to need a seed for deterministic outputs.

`DenoiseLatentsInvocation.init_scheduler()` has special handling for schedulers that accept a generator - the generator needs to be seeded in a particular way. At the time of this commit, these are the Invoke-supported schedulers that need this seed:
  - DDIMScheduler
  - DDPMScheduler
  - DPMSolverMultistepScheduler
  - EulerAncestralDiscreteScheduler
  - EulerDiscreteScheduler
  - KDPM2AncestralDiscreteScheduler
  - LCMScheduler
  - TCDScheduler

2. Adding noise during inpainting.

If a mask is used for denoising, and we are not using an inpainting model, we add noise to the unmasked area. If, for some reason, we have a mask but no noise, the seed is used to add noise.

I wonder if we should instead assert that if a mask is provided, we also have noise.

This is done in `invokeai/backend/stable_diffusion/diffusers_pipeline.py` in `StableDiffusionGeneratorPipeline.latents_from_embeddings()`.

When we create noise to be used in denoising, we are expected to set `LatentsField.seed` to the seed used to create the noise. This introduces some awkwardness when we manipulate any "latents" that will be used for denoising. We have to pass the seed along for every operation.

If the wrong seed or no seed is passed along, we can get unexpected outputs during denoising. One notable case relates to blending latents (slerping tensors).

If we slerp two noise tensors (`LatentsField`s) _without_ passing along the seed from the source latents, when we denoise with a seed-dependent scheduler*, the schedulers use the fallback seed of 0 and we get the wrong output. This is most obvious when slerping with a weight of 0, in which case we expect the exact same output after denoising.

*It looks like only the DPMSolver* schedulers are affected, but I haven't tested all of them.

Passing the seed along in the output fixes this issue.
2024-06-05 00:02:52 +10:00
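A simplified sketch of the fix described above, using a minimal stand-in for `LatentsField` and a plain lerp in place of slerp (the real InvokeAI types and blend math differ):

```
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class LatentsStandIn:
    """Minimal stand-in for LatentsField: a tensor plus the seed of the noise that produced it."""
    tensor: torch.Tensor
    seed: Optional[int] = None

def blend(latents_a: LatentsStandIn, latents_b: LatentsStandIn, alpha: float) -> LatentsStandIn:
    # Linear blend as a placeholder for slerp; the point here is the seed handling.
    blended = torch.lerp(latents_a.tensor, latents_b.tensor, alpha)
    # The fix: carry latents_a's seed into the output so seed-dependent
    # schedulers don't fall back to seed 0 during denoising.
    return LatentsStandIn(tensor=blended, seed=latents_a.seed)
```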
psychedelicious
64523c4b1b fix(ui): handle concat when recalling prompts
This required some minor reworking of the logic to recall multiple items. I split this into a utility function that includes some special handling for concat.

Closes #6478
2024-06-04 06:01:01 +10:00
psychedelicious
89a764a359 fix(ui): improve model metadata parsing fallback
When the model in metadata's key no longer exists, fall back to fetching by name, base and type. This was the intention all along but the logic was never put in place.
2024-06-04 06:01:01 +10:00
Lincoln Stein
756108f6bd Update invokeai/app/invocations/latent.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2024-06-03 11:41:47 -07:00
Lincoln Stein
68d628dc14 use zip to iterate over image prompts and adapters 2024-06-03 11:41:47 -07:00
Lincoln Stein
93c9852142 fix ruff 2024-06-03 11:41:47 -07:00
Lincoln Stein
493f81788c added a few comments to document design choices 2024-06-03 11:41:47 -07:00
Lincoln Stein
f13427e3f4 refactor redundant code and fix typechecking errors 2024-06-03 11:41:47 -07:00
Lincoln Stein
e28737fc8b add check for congruence between # of ip_adapters and image_prompts 2024-06-03 11:41:47 -07:00
Lincoln Stein
7391c126d3 handle case of no IP adapters requested 2024-06-03 11:41:47 -07:00
Lincoln Stein
1c59fce6ad reduce peak VRAM memory usage of IP adapter 2024-06-03 11:41:47 -07:00
psychedelicious
a9962fd104 chore: ruff 2024-06-03 11:53:20 +10:00
psychedelicious
e7513f6088 docs(mm): add comment in move_model_to_device 2024-06-03 10:56:04 +10:00
psychedelicious
c7f22b6a3b tidy(mm): remove extraneous docstring
It's inherited from the ABC.
2024-06-03 10:46:28 +10:00
psychedelicious
99413256ce tidy(mm): pass enum member instead of string 2024-06-03 10:43:09 +10:00
psychedelicious
aa9695e377 tidy(download): _download_job -> _multifile_job 2024-06-03 10:15:53 +10:00
psychedelicious
c58ac1e80d tidy(mm): minor formatting 2024-06-03 10:11:08 +10:00
psychedelicious
6cc6a45274 feat(download): add type for callback_name
Just a bit of typo protection in lieu of full type safety for these methods, which is difficult due to the typing of `DownloadEventHandler`.
2024-06-03 10:05:52 +10:00
psychedelicious
521f907f58 tidy(nodes): infill
- Set `self._context=context` instead of passing it as an arg
2024-06-03 09:43:25 +10:00
psychedelicious
ccdecf21a3 tidy(nodes): cnet processors
- Set `self._context=context` instead of changing the type signature of `run_processor`
- Tidy a few typing things
2024-06-03 09:41:17 +10:00
psychedelicious
b124440023 tidy(mm): move load_model_from_url from mm to invocation context 2024-06-03 08:51:21 +10:00
psychedelicious
e3a70e598e docs(app): simplify docstring in invocation_context 2024-06-03 08:40:29 +10:00
psychedelicious
132bbf330a tidy(app): remove unnecessary changes in invocation_context
- Any mypy issues are a misconfiguration of mypy
- Use simple conditionals instead of ternaries
- Consistent & standards-compliant docstring formatting
- Use `dict` instead of `typing.Dict`
2024-06-03 08:35:23 +10:00
Lincoln Stein
2276f327e5 Merge branch 'main' into lstein/feat/simple-mm2-api 2024-06-02 09:45:31 -04:00
psychedelicious
6b24424727 feat(ui): add help icon to compare toolbar 2024-06-02 15:30:00 +10:00
psychedelicious
7153d846a9 feat(ui): add hotkey to cycle compare modes 2024-06-02 15:30:00 +10:00
psychedelicious
9a0b77ad38 feat(ui): add hotkey to swap comparison images 2024-06-02 15:30:00 +10:00
psychedelicious
220d45967e fix(ui): typo 2024-06-02 15:30:00 +10:00
psychedelicious
038a482ef0 feat(ui): rework visibility conditions for image viewer 2024-06-02 15:30:00 +10:00
psychedelicious
c325ad3432 feat(ui): add hotkey hint to exit compare button 2024-06-02 15:30:00 +10:00
psychedelicious
449bc4dbe5 feat(ui): abstract out and share logic between comparisons 2024-06-02 15:30:00 +10:00
psychedelicious
34d68a3663 feat(ui): hover comparison mode 2024-06-02 15:30:00 +10:00
psychedelicious
8bb9571485 feat(ui): tweak slider divider styling 2024-06-02 15:30:00 +10:00
psychedelicious
08bcc71e99 fix(ui): workflows fit on load 2024-06-02 15:30:00 +10:00
psychedelicious
ff2b2fad83 feat(ui): revise drop zones
The main viewer area has two drop zones:
- Select for Viewer
- Select for Compare

These do what you'd imagine they would do.
2024-06-02 15:30:00 +10:00
psychedelicious
0f0a6852f1 fix(ui): make compare image scale with first image when using contain fit 2024-06-02 15:30:00 +10:00
psychedelicious
745140fa6b feat(ui): "first image"/"second image" -> "viewer image"/"compare image" 2024-06-02 15:30:00 +10:00
psychedelicious
405fc46888 feat(ui): z/esc first exit compare before closing viewer 2024-06-02 15:30:00 +10:00
psychedelicious
ca728ca29f fix(ui): ignore context menu in slider view
It doesn't make sense to allow context menu here, because the context menu will technically be on a div and not an image - there won't be any image options there.
2024-06-02 15:30:00 +10:00
psychedelicious
d0fca53e67 fix(ui): only clear comparison image on alt click of gallery image
This logic can't be in the reducer, else it applies to dnd events, which isn't right
2024-06-02 15:30:00 +10:00
psychedelicious
ad9740d72d feat(ui): alt-click comparison image exits compare 2024-06-02 15:30:00 +10:00
psychedelicious
1c9c982b63 feat(ui): use appropriate cursor on slider 2024-06-02 15:30:00 +10:00
psychedelicious
3cfd2755c2 fix(ui): when changing viewer state, always clear compare image 2024-06-02 15:30:00 +10:00
psychedelicious
8ea4067f83 feat(ui): rework compare toolbar 2024-06-02 15:30:00 +10:00
psychedelicious
940de6a5c5 fix(ui): allow drop of currently-selected image for compare 2024-06-02 15:30:00 +10:00
psychedelicious
dd74e89127 fix(ui): close context menu on click select for compare 2024-06-02 15:30:00 +10:00
psychedelicious
69da67e920 fix(ui): dnd on board
Copy-paste error broke this
2024-06-02 15:30:00 +10:00
psychedelicious
76b1f241d7 fix(ui): useGalleryNavigation callback typing issue 2024-06-02 15:30:00 +10:00
psychedelicious
0e5336d8fa feat(ui): rework comparison activation, add hotkeys 2024-06-02 15:30:00 +10:00
psychedelicious
3501636018 feat(ui): add fill mode for slider comparison 2024-06-02 15:30:00 +10:00
psychedelicious
e4ce188500 feat(ui): image selection gallery state & tweaks 2024-06-02 15:30:00 +10:00
psychedelicious
e976571fba build(ui): remove unused dep 2024-06-02 15:30:00 +10:00
psychedelicious
0da36c1238 feat(ui): use IAIDndImage for compare mode 2024-06-02 15:30:00 +10:00
psychedelicious
4ef8cbd9d0 fix(ui): use isValidDrop in imageDropped listener
It was possible for a drop event to be invalid but still processed. Fixed by slightly changing the signature of isValidDrop.
2024-06-02 15:30:00 +10:00
psychedelicious
8f8ddd620b feat(ui): add comparison modes, side-by-side view 2024-06-02 15:30:00 +10:00
psychedelicious
1af53aed60 feat(ui): fix image comparison slider resizing/aspect ratio jank 2024-06-02 15:30:00 +10:00
psychedelicious
7a4bbd092e feat(ui): revised image comparison slider
Should work for any components and image now.
2024-06-02 15:30:00 +10:00
psychedelicious
72bbcb2d94 feat(ui): slider working for all aspect ratios 2024-06-02 15:30:00 +10:00
psychedelicious
c2eef93476 feat(ui): wip slider implementations 2024-06-02 15:30:00 +10:00
blessedcoolant
cfb12615e1 fix: openapi stuff (#6454)
## Summary

Fix some issues with openapi schema generation. See commits for details.

## Related Issues / Discussions


https://discord.com/channels/1020123559063990373/1049495067846524939/1245141831394529352

## QA Instructions

App should work, workflows should work.

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-05-30 08:22:34 +05:30
psychedelicious
a983f27aad fix(ui): update types 2024-05-30 12:03:38 +10:00
psychedelicious
7cb32d3d83 chore(ui): typegen 2024-05-30 12:03:38 +10:00
psychedelicious
ac56ab79a7 fix(app): add dynamic validator to AnyInvocation & AnyInvocationOutput
This fixes the tests and slightly changes output types.
2024-05-30 12:03:38 +10:00
psychedelicious
50d3030471 feat(app): dynamic type adapters for invocations & outputs
Keep track of whether or not the typeadapter needs to be updated. Allows for dynamic invocation and output unions.
2024-05-30 12:03:38 +10:00
psychedelicious
5beec8211a feat(api): sort openapi schemas
Reduces the constant changes to the frontend client types due to inconsistent ordering of pydantic models.
2024-05-30 12:03:38 +10:00
psychedelicious
5a4d10467b feat(ui): use updated types 2024-05-30 12:03:38 +10:00
psychedelicious
7590f3005e chore(ui): typegen 2024-05-30 12:03:03 +10:00
psychedelicious
2f9ebdec69 fix(app): openapi schema generation
Some tech debt related to dynamic pydantic schemas for invocations became problematic. Including the invocations and results in the event schemas was breaking pydantic's handling of ref schemas. I don't really understand why - I think it's a pydantic bug in a remote edge case that we are hitting.

After many failed attempts I landed on this implementation, which is actually much tidier than what was in there before.

- Create pydantic-enabled types for `AnyInvocation` and `AnyInvocationOutput` and use these in place of the janky dynamic unions. Actually, they are kinda the same, but better encapsulated. Use these in `Graph`, `GraphExecutionState`, `InvocationEventBase` and `InvocationCompleteEvent`.
- Revise the custom openapi function to work with the new models.
- Split out the custom openapi function to a separate file. Add a `post_transform` callback so consumers can customize the output schema.
- Update makefile scripts.
2024-05-30 12:03:03 +10:00
psychedelicious
e257a72f94 chore: bump pydantic, fastapi to latest 2024-05-30 12:03:03 +10:00
psychedelicious
843f82c837 fix(ui): remove overly strict constraints on control adapter weight 2024-05-29 19:01:28 -07:00
psychedelicious
66858effa2 docs: add FAQ for fixing controlnet_aux 2024-05-29 18:19:06 -07:00
Lincoln Stein
21a60af881 when unlocking models, offload_unlocked_models should prune to vram limit only (#6450)
Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-05-29 03:01:21 +00:00
Lincoln Stein
ead1748c54 issue a download progress event when install download starts 2024-05-28 19:30:42 -04:00
Ryan Dick
df91d1b849 Update TI handling for compatibility with transformers 4.40.0 (#6449)
## Summary

- Updated the documentation for `TextualInversionManager`
- Updated the `self.tokenizer.model_max_length` access to work with the
latest transformers version. Thanks to @skunkworxdark for looking into
this here:
https://github.com/invoke-ai/InvokeAI/issues/6445#issuecomment-2133098342

## Related Issues / Discussions

Closes #6445 

## QA Instructions

I tested with `transformers==4.41.1`, and compared the results against a
recent InvokeAI version before updating tranformers - no change, as
expected.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-05-28 08:32:02 -04:00
Ryan Dick
829b9ad66b Add a callout about the hackiness of dropping tokens in the TextualInversionManager. 2024-05-28 05:11:54 -07:00
Ryan Dick
3aa1c8d3a8 Update TextualInversionManager for compatibility with the latest transformers release. See https://github.com/invoke-ai/InvokeAI/issues/6445. 2024-05-28 05:11:54 -07:00
Ryan Dick
994c61b67a Add docs to TextualInversionManager and improve types. No changes to functionality. 2024-05-28 05:11:54 -07:00
psychedelicious
21aa42627b feat(events): add dynamic invocation & result validators
This is required to get these event fields to deserialize correctly. If omitted, pydantic uses `BaseInvocation`/`BaseInvocationOutput`, which is not correct.

This is similar to the workaround in the `Graph` and `GraphExecutionState` classes where we need to fanagle pydantic with manual validation handling.
2024-05-28 05:11:37 -07:00
psychedelicious
a4f88ff834 feat(events): add __event_name__ as ClassVar to EventBase
This improves types for event consumers that need to access the event name.
2024-05-28 05:11:37 -07:00
Lincoln Stein
cd12ca6e85 add migration_11; fix typo 2024-05-27 22:40:01 -04:00
Lincoln Stein
34e1eb19f9 merge with main and resolve conflicts 2024-05-27 22:20:34 -04:00
psychedelicious
ddff9b4584 fix(events): typing for download event handler 2024-05-27 11:13:47 +10:00
psychedelicious
b50133d5e1 feat(events): register event schemas
This allows for events to be dispatched using dicts as payloads, and have the dicts validated as pydantic schemas.
2024-05-27 11:13:47 +10:00
psychedelicious
5388f5a817 fix(ui): edit variant for main models only
Closes #6444
2024-05-27 11:02:00 +10:00
psychedelicious
27a3eb15f8 feat(ui): update event types 2024-05-27 10:17:02 +10:00
psychedelicious
4b2d57a5e0 chore(ui): typegen
Note about the huge diff: I had a different version of pydantic installed at some point, which slightly altered a _ton_ of schema components. This typegen was done on the correct version of pydantic and undoes those alterations, in addition to the intentional changes to event models.
2024-05-27 10:17:02 +10:00
psychedelicious
bbb90ff949 feat(events): restore whole invocation to event payloads
Removing this is a breaking API change - some consumers of the events need the whole invocation. Didn't realize that until now.
2024-05-27 10:17:02 +10:00
psychedelicious
9d9801b2c2 feat(events): stronger generic typing for event registration 2024-05-27 10:17:02 +10:00
psychedelicious
8498d4344b docs: update docstrings in sockets.py 2024-05-27 09:06:02 +10:00
psychedelicious
dfad37a262 docs: update comments & docstrings 2024-05-27 09:06:02 +10:00
psychedelicious
89dede7bad feat(ui): simplify client sio redux actions
- Add a simple helper to create socket actions in a less error-prone way
- Organize and tidy sio files
2024-05-27 09:06:02 +10:00
psychedelicious
60784a4361 feat(ui): update client for removal of session events 2024-05-27 09:06:02 +10:00
psychedelicious
3d8774d295 chore(ui): typegen 2024-05-27 09:06:02 +10:00
psychedelicious
084cf26ed6 refactor: remove all session events
There's no longer any need for session-scoped events now that we have the session queue. Session started/completed/canceled map 1-to-1 to queue item status events, but queue item status events also have an event for failed state.

We can simplify queue and processor handling substantially by removing session events and instead using queue item events.

- Remove the session-scoped events entirely.
- Remove all event handling from session queue. The processor still needs to respond to some events from the queue: `QueueClearedEvent`, `BatchEnqueuedEvent` and `QueueItemStatusChangedEvent`.
- Pass an `is_canceled` callback to the invocation context instead of the cancel event
- Update processor logic to ensure the local instance of the current queue item is synced with the instance in the database. This prevents race conditions and ensures lifecycle callbacks do not get stale queue items.
- Update docstrings and comments
- Add `complete_queue_item` method to session queue service as an explicit way to mark a queue item as successfully completed. Previously, the queue listened for session complete events to do this.

Closes #6442
2024-05-27 09:06:02 +10:00
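A minimal sketch of the `is_canceled` callback idea from the commit above, with all names illustrative:

```
from typing import Callable

class InvocationContextStandIn:
    """Illustrative only: the context receives a callback instead of a cancel event."""

    def __init__(self, is_canceled: Callable[[], bool]):
        self._is_canceled = is_canceled

    def is_canceled(self) -> bool:
        return self._is_canceled()

# The processor wires the callback to the current queue item's state:
current_item_canceled = False
ctx = InvocationContextStandIn(is_canceled=lambda: current_item_canceled)

# Long-running nodes can poll the callback and bail out early:
if ctx.is_canceled():
    raise RuntimeError("canceled")  # illustrative handling
```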
psychedelicious
8592f5c6e1 feat(events): move event sets outside sio class
This lets the event sets be consumed programmatically.
2024-05-27 09:06:02 +10:00
psychedelicious
368127bd25 feat(events): register_events supports single event 2024-05-27 09:06:02 +10:00
psychedelicious
c0aabcd8ea tidy(events): use tuple index access for event payloads 2024-05-27 09:06:02 +10:00
psychedelicious
ed6c716ddc fix(mm): emit correct event when model load complete 2024-05-27 09:06:02 +10:00
psychedelicious
eaf67b2150 feat(ui): add logging for session events 2024-05-27 09:06:02 +10:00
psychedelicious
575943d0ad fix(processor): move session started event to session runner 2024-05-27 09:06:02 +10:00
psychedelicious
25d1d2b591 tidy(processor): use separate handlers for each event type
Just a bit clearer without needing `isinstance` checks.
2024-05-27 09:06:02 +10:00
psychedelicious
39415428de chore(ui): typegen 2024-05-27 09:06:02 +10:00
psychedelicious
64d553f72c feat(events): restore temp handling of user/project 2024-05-27 09:06:02 +10:00
psychedelicious
5b390bb11c tests: clean up tests after events changes 2024-05-27 09:06:02 +10:00
psychedelicious
a9f773c03c fix(mm): port changes into new model_install_common file
Some subtle changes happened between this PR's last update and now. Bring them into the file.
2024-05-27 09:06:02 +10:00
psychedelicious
585feccf82 fix(ui): update event handling to match new types 2024-05-27 09:06:02 +10:00
psychedelicious
cbd3b15cae chore(ui): typegen 2024-05-27 09:06:02 +10:00
psychedelicious
cc56918453 tidy(ui): remove old unused session subscribe actions 2024-05-27 09:06:02 +10:00
psychedelicious
f82df2661a docs: clarify comment in api_app 2024-05-27 09:06:02 +10:00
psychedelicious
a1d68eb319 fix(ui): denoise percentage 2024-05-27 09:06:02 +10:00
psychedelicious
8b5caa7e57 chore(ui): typegen 2024-05-27 09:06:02 +10:00
psychedelicious
b3a051250f feat(api): sort socket event names for openapi schema
Deterministic ordering prevents extraneous, non-functional changes to the autogenerated types
2024-05-27 09:06:02 +10:00
psychedelicious
0f733c42fc fix(events): fix denoise progress percentage
- Restore calculation of step percentage but in the backend instead of client
- Simplify signatures for denoise progress event callbacks
- Clean up `step_callback.py` (types, do not recreate constant matrix on every step, formatting)
2024-05-27 09:06:02 +10:00
psychedelicious
ec4f10aed3 chore(ui): typegen 2024-05-27 09:06:02 +10:00
psychedelicious
d97186dfc8 feat(events): remove payload registry, add method to get event classes
We don't need to use the payload schema registry. All our events are dispatched as pydantic models, which are already validated on instantiation.

We do want to add all events to the OpenAPI schema, and we referred to the payload schema registry for this. To get all events, add a simple helper to EventBase. This is functionally identical to using the schema registry.
2024-05-27 09:06:02 +10:00
psychedelicious
18b4f1b72a feat(ui): add missing socket events 2024-05-27 09:06:02 +10:00
psychedelicious
5cdf71b72f feat(events): add missing events
These events weren't being emitted via socket.io:
- DownloadCancelledEvent
- DownloadCompleteEvent
- DownloadErrorEvent
- DownloadProgressEvent
- DownloadStartedEvent
- ModelInstallDownloadsCompleteEvent
2024-05-27 09:06:02 +10:00
psychedelicious
88a2340b95 feat(events): use builder pattern for download events 2024-05-27 09:06:02 +10:00
psychedelicious
1be4cab2d9 fix(events): dump events with mode="json"
Ensures all model events are serializable.
2024-05-27 09:06:02 +10:00
psychedelicious
567b87cc50 docs(events): update event docstrings 2024-05-27 09:06:02 +10:00
psychedelicious
4756920282 tests: move fixtures import to conftest.py 2024-05-27 09:06:02 +10:00
psychedelicious
a876675448 tests: update tests to use new events 2024-05-27 09:06:02 +10:00
psychedelicious
655f62008f fix(mm): check for presence of invoker before emitting model load event
The model loader emits events. During testing, it doesn't have access to a fully-mocked events service, so the test fails when attempting to call a nonexistent method. There was a check for this previously, but I accidentally removed it. Restored.
2024-05-27 09:06:02 +10:00
psychedelicious
300725d1dd fix(ui): correct model load event format 2024-05-27 09:06:02 +10:00
psychedelicious
bf03127c69 fix(events): add missing __event_name__ to EventBase 2024-05-27 09:06:02 +10:00
psychedelicious
2dc752ea83 feat(events): simplify event classes
- Remove ABCs, they do not work well with pydantic
- Remove the event type classvar - unused
- Remove clever logic to require an event name - we already get validation for this during schema registration.
- Rename event bases to all end in "Base"
2024-05-27 09:06:02 +10:00
psychedelicious
1b9bbaa5a4 fix(events): emit bulk download events in correct room 2024-05-27 09:06:02 +10:00
psychedelicious
3abc182b44 chore(ui): tidy after rebase 2024-05-27 09:06:02 +10:00
psychedelicious
8d79ce94aa feat(ui): update UI to use new events
- Use OpenAPI schema for event payload types
- Update all event listeners
- Add missing events / remove old nonexistent events
2024-05-27 09:06:02 +10:00
psychedelicious
975dc14579 chore(ui): typegen 2024-05-27 09:06:02 +10:00
psychedelicious
9bd78823a3 refactor(events): use pydantic schemas for events
Our event handling and implementation has a couple of pain points:
- Adding or removing data from event payloads requires changes wherever the events are dispatched from.
- We have no type safety for events and need to rely on string matching and dict access when interacting with events.
- Frontend types for socket events must be manually typed. This has caused several bugs.

`fastapi-events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly.

This allows us to eliminate a layer of indirection and some unpleasant complexity:
- Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed.
- Event payload construction is now the responsibility of the event itself (a pydantic model), not the service. Every event model has a `build` class method, encapsulating this logic. The build methods are provided as few args as possible. For example, `InvocationStartedEvent.build()` gets the invocation instance and queue item, and can choose the data it wants to include in the event payload.
- Frontend event types may be autogenerated from the OpenAPI schema. We use the payload registry feature of `fastapi-events` to collect all payload models into one place, making it trivial to keep our schema and frontend types in sync.

This commit moves the backend over to this improved event handling setup.
2024-05-27 09:06:02 +10:00
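A minimal sketch of this event pattern - a pydantic model carrying `__event_name__` and a `build()` classmethod, as described above. `InvocationStartedEvent` is named in the commit; its fields and the base-class details here are assumptions:

```
from typing import ClassVar

from pydantic import BaseModel

class EventBase(BaseModel):
    """Base class for events dispatched as pydantic models."""
    __event_name__: ClassVar[str]

class InvocationStartedEvent(EventBase):
    __event_name__: ClassVar[str] = "invocation_started"

    queue_id: str
    item_id: int
    invocation_id: str

    @classmethod
    def build(cls, queue_item, invocation) -> "InvocationStartedEvent":
        # The event itself decides which data to include in its payload,
        # given only the queue item and the invocation instance.
        return cls(
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            invocation_id=invocation.id,
        )
```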
psychedelicious
461e857824 fix(ui): parameter not set translation 2024-05-26 08:21:06 -07:00
Wubbbi
48db0b90e8 Bump transformers 2024-05-26 12:51:07 +10:00
Wubbbi
c010ce49f7 Bump huggingface-hub 2024-05-26 12:51:07 +10:00
Wubbbi
6df8b23c59 Bump transformers 2024-05-26 12:51:07 +10:00
Wubbbi
dfe02b26c1 Bump accelerate 2024-05-26 12:51:07 +10:00
Wubbbi
4142dc7141 Update deps to their latest version 2024-05-26 12:51:07 +10:00
Shukri
86bfcc53a3 docs: fix typo (#6395)
may noise steps -> many noise steps
2024-05-24 18:02:17 +00:00
Lincoln Stein
532f82cb97 Optimize RAM to VRAM transfer (#6312)
* avoid copying model back from cuda to cpu

* handle models that don't have state dicts

* add assertions that models need a `device()` method

* do not rely on torch.nn.Module having the device() method

* apply all patches after model is on the execution device

* fix model patching in latents too

* log patched tokenizer

* closes #6375

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-05-24 17:06:09 +00:00
cdpath
7437085cac fix typo (#6255) 2024-05-24 15:26:05 +00:00
psychedelicious
e9b80cf28f fix(ui): isLocal erroneously hardcoded 2024-05-25 00:05:44 +10:00
psychedelicious
f5a775ae4e feat(ui): toast on queue item errors, improved error descriptions
Show error toasts on queue item error events instead of invocation error events. This allows errors that occurred outside node execution to be surfaced to the user.

The error description component is updated to show the new error message if available. Commercial handling is retained, but local now uses the same component to display the error message itself.
2024-05-24 20:02:24 +10:00
psychedelicious
50dd569411 fix(processor): race condition that could result in node errors not getting reported
I had set the cancel event at some point during troubleshooting an unrelated issue. It seemed logical that it should be set there, and didn't seem to break anything. However, this is not correct.

The cancel event should not be set in response to a queue status change event. Doing so can cause a race condition when nodes are executed very quickly.

It's possible that a previously-executed session's queue item status change event is handled after the next session starts executing. The cancel event is set, and the session runner sees it, aborting the session run early.

In hindsight, it doesn't make sense to set the cancel event here either. It should be set in response to user action, e.g. the user cancelled the session or cleared the queue (which implicitly cancels the current session). These events actually trigger the queue item status changed event, so if we set the cancel event here, we'd be setting it twice per cancellation.
2024-05-24 20:02:24 +10:00
psychedelicious
125e1d7eb4 tidy: remove unnecessary whitespace changes 2024-05-24 20:02:24 +10:00
psychedelicious
2fbe5ecb00 fix(ui): correctly fallback to error message when traceback is empty string 2024-05-24 20:02:24 +10:00
psychedelicious
ba4d27860f tidy(ui): remove extraneous condition in socketInvocationError 2024-05-24 20:02:24 +10:00
psychedelicious
6fc7614b4a fix(ui): race condition with progress
There's a race condition where a canceled session may emit a progress event or two after it's been canceled, and the progress image isn't cleared out.

To resolve this, the system slice tracks canceled session ids. When a progress event comes in, we check the cancellations and skip setting the progress if canceled.
2024-05-24 20:02:24 +10:00
psychedelicious
9c926f249f feat(processor): add debug log stmts to session running callbacks 2024-05-24 20:02:24 +10:00
psychedelicious
80faeac913 fix(processor): fix race condition related to clearing the queue 2024-05-24 20:02:24 +10:00
psychedelicious
418c932595 tidy(processor): remove test callbacks 2024-05-24 20:02:24 +10:00
psychedelicious
9117db2673 tidy(queue): delete unused delete_queue_item method 2024-05-24 20:02:24 +10:00
psychedelicious
4a48aa98a4 chore: ruff 2024-05-24 20:02:24 +10:00
psychedelicious
e365d35c93 docs(processor): update docstrings, comments 2024-05-24 20:02:24 +10:00
psychedelicious
aa329ea811 feat(ui): handle enriched events 2024-05-24 20:02:24 +10:00
psychedelicious
1e622a5706 chore(ui): typegen 2024-05-24 20:02:24 +10:00
psychedelicious
ae66d32b28 feat(app): update test event callbacks 2024-05-24 20:02:24 +10:00
psychedelicious
2dd3a85ade feat(processor): update enriched errors & fail_queue_item() 2024-05-24 20:02:24 +10:00
psychedelicious
a8492bd7e4 feat(events): add enriched errors to events 2024-05-24 20:02:24 +10:00
psychedelicious
25954ea750 feat(queue): session queue error handling
- Add handling for new error columns `error_type`, `error_message`, `error_traceback`.
- Update queue item model to include the new data. The `error_traceback` field has an alias of `error` for backwards compatibility.
- Add `fail_queue_item` method. This was previously handled by `cancel_queue_item`. Splitting this functionality makes failing a queue item a bit more explicit. We also don't need to handle multiple optional error args.
2024-05-24 20:02:24 +10:00
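A sketch of how the new error fields might look on a pydantic queue item model (the field set beyond the three error columns, and the exact alias mechanics, are assumptions):

```
from typing import Optional

from pydantic import BaseModel, Field

class SessionQueueItem(BaseModel):
    item_id: int
    status: str

    # New error columns; error_traceback keeps an "error" alias for backwards compatibility.
    error_type: Optional[str] = None
    error_message: Optional[str] = None
    error_traceback: Optional[str] = Field(default=None, alias="error")
```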
psychedelicious
887b73aece feat(db): add error_type, error_message, rename error -> error_traceback to session_queue table 2024-05-24 20:02:24 +10:00
psychedelicious
3c41c67d13 fix(processor): restore missing update of session 2024-05-24 20:02:24 +10:00
psychedelicious
6c79be7dc3 chore: ruff 2024-05-24 20:02:24 +10:00
psychedelicious
097619ef51 feat(processor): get user/project from queue item w/ fallback 2024-05-24 20:02:24 +10:00
psychedelicious
a1f7a9cd6f fix(app): fix logging of error classes instead of class names 2024-05-24 20:02:24 +10:00
psychedelicious
25b9c19eed feat(app): handle preparation errors as node errors
We were not handling node preparation errors as node errors before. Here's the explanation, copied from a comment that is no longer required:

---

TODO(psyche): Sessions only support errors on nodes, not on the session itself. When an error occurs outside
node execution, it bubbles up to the processor where it is treated as a queue item error.

Nodes are pydantic models. When we prepare a node in `session.next()`, we set its inputs. This can cause a
pydantic validation error. For example, consider a resize image node which has a constraint on its `width`
input field - it must be greater than zero. During preparation, if the width is set to zero, pydantic will
raise a validation error.

When this happens, it breaks the flow before `invocation` is set. We can't set an error on the invocation
because we didn't get far enough to get it - we don't know its id. Hence, we just set it as a queue item error.

---

This change wraps the node preparation step with exception handling. A new `NodeInputError` exception is raised when there is a validation error. This error has the node (in the state it was in just prior to the error) and an identifier of the input that failed.

This allows us to mark the node that failed preparation as errored, correctly making such errors _node_ errors and not _processor_ errors. It's much easier to diagnose these situations. The error messages look like this:

> Node b5ac87c6-0678-4b8c-96b9-d215aee12175 has invalid incoming input for height

Some of the exception handling logic is cleaned up.
2024-05-24 20:02:24 +10:00
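For illustration only (the exact implementation isn't shown in this log): a rough sketch of the pattern described above, where a pydantic validation failure during preparation is re-raised as a `NodeInputError` carrying the node and the offending input name. The `ResizeImageNode` stand-in and `prepare_inputs` helper are hypothetical.
```
from pydantic import BaseModel, ConfigDict, Field, ValidationError

class NodeInputError(Exception):
    """Raised when setting a node's inputs fails validation during preparation."""

    def __init__(self, node: BaseModel, input_name: str) -> None:
        self.node = node
        self.input_name = input_name
        super().__init__(f"Node {getattr(node, 'id', '?')} has invalid incoming input for {input_name}")

# Hypothetical stand-in for a real invocation; width/height must be greater than zero.
class ResizeImageNode(BaseModel):
    model_config = ConfigDict(validate_assignment=True)
    id: str
    width: int = Field(default=512, gt=0)
    height: int = Field(default=512, gt=0)

def prepare_inputs(node: ResizeImageNode, incoming: dict[str, int]) -> ResizeImageNode:
    for name, value in incoming.items():
        try:
            setattr(node, name, value)  # validate_assignment re-validates the field here
        except ValidationError as e:
            raise NodeInputError(node, name) from e
    return node

node = ResizeImageNode(id="b5ac87c6")
try:
    prepare_inputs(node, {"height": 0})
except NodeInputError as err:
    print(err)  # Node b5ac87c6 has invalid incoming input for height
```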
psychedelicious
cc2d877699 docs(app): explain why errors are handled poorly 2024-05-24 20:02:24 +10:00
psychedelicious
be82404759 tidy(app): "outputs" -> "output" 2024-05-24 20:02:24 +10:00
psychedelicious
33f9fe2c86 tidy(app): rearrange processor 2024-05-24 20:02:24 +10:00

psychedelicious
1d973f92ff feat(app): support multiple processor lifecycle callbacks 2024-05-24 20:02:24 +10:00
psychedelicious
7f70cde038 feat(app): make things in session runner private 2024-05-24 20:02:24 +10:00
psychedelicious
47722528a3 feat(app): iterate on processor split 2
- Use protocol to define callbacks, this allows them to have kwargs
- Shuffle the profiler around a bit
- Move `thread_limit` and `polling_interval` to `__init__`; `start` is called programmatically and will never get these args in practice
2024-05-24 20:02:24 +10:00
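A minimal sketch of the Protocol-based callback idea (so callbacks can take keyword arguments), with illustrative class and parameter names rather than the actual classes on this branch:
```
from typing import Optional, Protocol

class OnNonFatalProcessorError(Protocol):
    # A Protocol lets each callback declare a keyword-argument signature
    # while callers remain duck-typed.
    def __call__(self, *, error: Exception, queue_item_id: Optional[int] = None) -> None: ...

class SessionProcessor:
    def __init__(
        self,
        on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
        thread_limit: int = 1,
        polling_interval: float = 1.0,
    ) -> None:
        # thread_limit and polling_interval are constructor args now, since
        # start() is called programmatically and never receives them.
        self._on_error = on_non_fatal_processor_error
        self._thread_limit = thread_limit
        self._polling_interval = polling_interval

    def _report(self, error: Exception) -> None:
        if self._on_error is not None:
            self._on_error(error=error, queue_item_id=None)
```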
psychedelicious
be41c84305 feat(app): iterate on processor split
- Add `OnNodeError` and `OnNonFatalProcessorError` callbacks
- Move all session/node callbacks to `SessionRunner` - this ensures we dump perf stats before resetting them and generally makes sense to me
- Remove `complete` event from `SessionRunner`, it's essentially the same as `OnAfterRunSession`
- Remove extraneous `next_invocation` block, which would treat a processor error as a node error
- Simplify loops
- Add some callbacks for testing, to be removed before merge
2024-05-24 20:02:24 +10:00
brandonrising
82b4298b03 Fix next node calling logic 2024-05-24 20:02:24 +10:00
brandonrising
fa6c7badd6 Run ruff 2024-05-24 20:02:24 +10:00
brandonrising
45d2504c1e Break apart session processor and the running of each session into separate classes 2024-05-24 20:02:24 +10:00
psychedelicious
f1bb7e86c0 feat(ui): invalidate cache for queue item on status change
This query is only subscribed to in the `QueueItemDetail` component, which is rendered only when the user clicks on a queue item in the queue. Invalidating this tag instead of optimistically updating it won't cause any meaningful change to network traffic.
2024-05-24 08:59:49 +10:00
psychedelicious
93e4c3dbc2 feat(app): update queue item's session on session completion
The session is never updated in the queue after it is first enqueued. As a result, the queue detail view in the frontend never updates and the session itself doesn't show outputs, execution graph, etc.

We need a new method on the queue service to update a queue item's session, then call it before updating the queue item's status.

Queue item status may be updated via a session-type event _or_ a queue-type event. Adding the updated session to all of these events is hairy; it's simpler to just update the session before we do anything that could trigger a queue item status change event:
- Before calling `emit_session_complete` in the processor (handles session error, completed and cancel events and the corresponding queue events)
- Before calling `cancel_queue_item` in the processor (handles another way queue items can be canceled, outside the session execution loop)

When serializing the session, both in the new service method and the `get_queue_item` endpoint, we need to use `exclude_none=True` to prevent unexpected validation errors.
2024-05-24 08:59:49 +10:00
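Not the real service code - a sketch under assumed names (`item_id`, a `session` column, a stand-in session model) of updating the stored session with `exclude_none=True` before any status-change event can fire:
```
import sqlite3
from pydantic import BaseModel

class SessionState(BaseModel):
    # Stand-in for the real session model; only here to make the sketch self-contained.
    id: str
    results: dict[str, str] = {}

def set_queue_item_session(conn: sqlite3.Connection, item_id: int, session: SessionState) -> None:
    # Serialize with exclude_none=True so optional fields that were never set
    # don't round-trip as explicit nulls and trigger validation errors later.
    conn.execute(
        "UPDATE session_queue SET session = ? WHERE item_id = ?",
        (session.model_dump_json(exclude_none=True), item_id),
    )
    conn.commit()

# Call this before emit_session_complete(...) and before cancel_queue_item(...),
# i.e. before anything that can emit a queue item status change event.
```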
gallegonovato
c3f28f7a35 translationBot(ui): update translation (Spanish)
Currently translated at 30.5% (380 of 1243 strings)

Co-authored-by: gallegonovato <fran-carro@hotmail.es>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/es/
Translation: InvokeAI/Web UI
2024-05-24 08:05:45 +10:00
Hosted Weblate
c900a63842 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-05-24 08:05:45 +10:00
psychedelicious
4eb5f004e6 Update invokeai_version.py 2024-05-24 08:00:03 +10:00
blessedcoolant
bcae735d7c fix(ui): initial image layers always ignored (#6434)
## Summary

Whoops!

## Related Issues / Discussions


https://discord.com/channels/1020123559063990373/1049495067846524939/1243186572115837009

## QA Instructions

- Generate w/ initial image layer

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-05-24 03:16:18 +05:30
blessedcoolant
861f06c459 Merge branch 'main' into psyche/fix/ui/initial-image-layer 2024-05-24 03:14:18 +05:30
blessedcoolant
c493628272 fix(ui): 'undefined' being used for metadata on uploaded images (#6433)
## Summary

TIL if you add `undefined` to a form data object, it gets stringified to
`'undefined'`. Whoops!

## Related Issues / Discussions

n/a

## QA Instructions

n/a

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-05-24 03:14:02 +05:30
psychedelicious
46a90ca402 fix(ui): initial image layers always ignored
Whoops!
2024-05-24 06:40:48 +10:00
psychedelicious
d45c33b446 fix(ui): 'undefined' being used for metadata on uploaded images 2024-05-24 06:17:07 +10:00
psychedelicious
88025d32c2 feat(api): downgrade metadata parse warnings to debug
I set these to warn during testing and neglected to undo the change.
2024-05-23 22:48:34 +10:00
psychedelicious
af64764082 fix: remove db maintenance script from launcher
It is broken.
2024-05-23 22:39:55 +10:00
psychedelicious
70487f0c2e fix(ui): layers are "enabled", not "visible" 2024-05-23 10:14:34 +10:00
psychedelicious
55d7d9cc75 fix(ui): control layers don't disable correctly
Closes #6424
2024-05-23 10:14:34 +10:00
Mary Hipp
106674175c add logo and change text for non-local; 2024-05-23 06:51:13 +10:00
Mary Hipp
dd1d5bdb25 use support URL for non-local 2024-05-23 06:51:13 +10:00
Dennis
6259ac0bec translationBot(ui): update translation (Dutch)
Currently translated at 79.6% (973 of 1222 strings)

Co-authored-by: Dennis <dennis@vanzoerlandt.nl>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/nl/
Translation: InvokeAI/Web UI
2024-05-22 09:51:12 +10:00
Riccardo Giovanetti
ba31f8a9a9 translationBot(ui): update translation (Italian)
Currently translated at 98.5% (1210 of 1228 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1206 of 1224 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1204 of 1222 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-05-22 09:51:12 +10:00
psychedelicious
0ba57d6dc5 feat(ui): close starter models toast when a model is installed 2024-05-22 09:40:46 +10:00
psychedelicious
abc133e936 feat(ui): revised invocation error toast handling
Only display the session if local. Otherwise, just display the error message.
2024-05-22 09:40:46 +10:00
psychedelicious
57743239d7 feat(ui): add updateDescription flag to toast API
If false, when updating a toast, the description is left alone. The count will still tick up.
2024-05-22 09:40:46 +10:00
psychedelicious
4a394c60cf feat(ui): add isLocal flag to config 2024-05-22 09:40:46 +10:00
psychedelicious
624d28a93d feat(ui): invocation error toasts do not autoclose 2024-05-22 09:40:46 +10:00
psychedelicious
29e1ea59fc feat(ui): style copy button on ToastWithSessionRefDescription 2024-05-22 09:40:46 +10:00
psychedelicious
2e5d24f272 tidy(ui): remove old comment 2024-05-22 09:40:46 +10:00
psychedelicious
1afa340b1a fix(ui): show toast when recalling seed 2024-05-22 09:40:46 +10:00
psychedelicious
3b381b5a8c tidy(ui): remove the ToastID enum
With the model install logic cleaned up, the enum is less useful
2024-05-22 09:40:46 +10:00
psychedelicious
f2b9684de8 tidy(ui): split install model into helper hook
This was duplicated like 7 times or so
2024-05-22 09:40:46 +10:00
psychedelicious
a66b3497e0 feat(ui): port all toasts to use new util 2024-05-22 09:40:46 +10:00
psychedelicious
683ec8e5f2 feat(ui): add stateful toast utility
A small wrapper around chakra's toast system that simplifies creating and updating toasts. See comments in toast.ts for details.
2024-05-22 09:40:46 +10:00
psychedelicious
f31f0cf733 feat(ui): restore spellcheck on prompt boxes 2024-05-22 08:52:25 +10:00
psychedelicious
38265b3123 docs(ui): update validateWorkflow comments 2024-05-21 05:17:10 -07:00
psychedelicious
caca28286c tests(ui): add test for resource usage check 2024-05-21 05:17:10 -07:00
psychedelicious
38320a5100 feat(ui): reset missing images, boards and models when loading workflows
These fields are reset back to `undefined` if not accessible. A warning toast is shown, and the full warning message is logged in the JS console.
2024-05-21 05:17:10 -07:00
Shukri
7badaab17d docs: fix link to invoke ai models site 2024-05-20 20:48:42 -07:00
blessedcoolant
aa0c59bb51 fix(ui): crash when using notes nodes or missing node/field templates (#6412)
## Summary

Notes nodes used some overly-strict redux selectors. The selectors are
now more chill. Also fixed an issue where you couldn't edit a notes node
title.

Found another class of error related to the overly strict reducers that caused errors when loading a workflow that had missing templates. Fixed this with a fallback wrapper component, which works like an error boundary when a template isn't found.

## Related Issues / Discussions


https://discord.com/channels/1020123559063990373/1149506274971631688/1242256425527545949

## QA Instructions

- Add a notes node to a workflow. Edit the notes title.
- Load a workflow that has nodes that aren't installed. Should get a
fallback UI for each missing node.
- Load a workflow that references a node with different inputs than are in the template - like an old version of a node. Should get a fallback field warning for both missing templates and missing inputs.

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-05-21 07:59:43 +05:30
psychedelicious
e4acaa5c8f chore: v4.2.2post1 2024-05-21 11:31:06 +10:00
psychedelicious
9ba47cae20 fix(ui): unable to edit notes node title 2024-05-21 11:27:11 +10:00
psychedelicious
bf4310ca71 fix(ui): errors when node template or field template doesn't exist
Some asserts were bubbling up in places where they shouldn't have, causing errors when a node has a field without a matching template, or vice-versa.

To resolve this without sacrificing the runtime safety provided by asserts, an `InvocationFieldCheck` component now wraps all field components. This component renders a fallback when a field doesn't exist, so the inner components can safely use the asserts.
2024-05-21 11:22:08 +10:00
psychedelicious
e75f98317f fix(ui): notes node text not selectable 2024-05-21 10:06:25 +10:00
psychedelicious
1249d4a6e3 fix(ui): crash when using a notes node 2024-05-21 10:06:09 +10:00
Lincoln Stein
987ee704a1 Merge branch 'main' into lstein/feat/simple-mm2-api 2024-05-17 22:54:03 -04:00
Lincoln Stein
e77c7e40b7 fix ruff error 2024-05-17 22:53:45 -04:00
Lincoln Stein
8aebc29b91 fix test to run on 32bit cpu 2024-05-17 22:48:54 -04:00
Lincoln Stein
d968c6f379 refactor multifile download code 2024-05-17 22:29:19 -04:00
Lincoln Stein
2dae5eb7ad more refactoring; HF subfolders not working 2024-05-16 22:26:18 -04:00
Lincoln Stein
911a24479b add tests for model install file size reporting 2024-05-16 07:18:33 -04:00
Lincoln Stein
f29c406fed refactor model_install to work with refactored download queue 2024-05-13 22:49:15 -04:00
Lincoln Stein
287c679f7b clean up type checking for single file and multifile download job callbacks 2024-05-13 18:31:40 -04:00
Lincoln Stein
0bf14c2830 add multifile_download() method to download service 2024-05-12 20:14:00 -06:00
Lincoln Stein
b48d4a049d bad implementation of diffusers folder download 2024-05-08 21:21:01 -07:00
Lincoln Stein
f211c95dbc move access token regex matching into download queue 2024-05-05 21:00:31 -04:00
Lincoln Stein
8e5e9b53d6 Merge branch 'main' into lstein/feat/simple-mm2-api 2024-05-04 17:01:15 -04:00
Lincoln Stein
e9a20051bd refactor DWOpenPose and add type hints 2024-05-03 18:08:53 -04:00
Lincoln Stein
38df6f3702 fix ruff error 2024-05-02 21:22:33 -04:00
Lincoln Stein
3b64e7a1fd Merge branch 'main' into lstein/feat/simple-mm2-api 2024-05-02 21:20:35 -04:00
Lincoln Stein
49c84cd423 Merge branch 'main' into lstein/feat/simple-mm2-api 2024-04-30 18:13:42 -04:00
psychedelicious
1fe90c357c feat(backend): lift managed model loading out of depthanything class 2024-04-29 08:56:00 +10:00
psychedelicious
fcb071f30c feat(backend): lift managed model loading out of lama class 2024-04-29 08:12:51 +10:00
Lincoln Stein
57c831442e fix safe_filename() on windows 2024-04-28 14:42:40 -04:00
Lincoln Stein
f65c7e2bfd Merge branch 'main' into lstein/feat/simple-mm2-api 2024-04-28 13:42:54 -04:00
Lincoln Stein
7c39929758 support VRAM caching of dict models that lack to() 2024-04-28 13:41:06 -04:00
Lincoln Stein
a26667d3ca make download and convert cache keys safe for filename length 2024-04-28 12:24:36 -04:00
Lincoln Stein
bb04f496e0 Merge branch 'main' into lstein/feat/simple-mm2-api 2024-04-28 11:33:26 -04:00
Lincoln Stein
70903ef057 refactor load_ckpt_from_url() 2024-04-28 11:33:23 -04:00
Lincoln Stein
d72f272f16 Address change requests in first round of PR reviews.
Pending:

- Move model install calls into model manager and create passthrus in invocation_context.
- Consider splitting load_model_from_url() into a call to get the path and a call to load the path.
2024-04-24 23:53:30 -04:00
Lincoln Stein
34cdfc61ab Merge branch 'main' into lstein/feat/simple-mm2-api 2024-04-17 17:18:13 -04:00
Lincoln Stein
470a39935c fix merge conflicts with main 2024-04-15 09:24:57 -04:00
Lincoln Stein
f1e79d5a8f Merge branch 'main' into lstein/feat/simple-mm2-api 2024-04-15 09:14:55 -04:00
Lincoln Stein
f055e1edb6 Merge branch 'lstein/feat/simple-mm2-api' of github.com:invoke-ai/InvokeAI into lstein/feat/simple-mm2-api 2024-04-15 09:14:37 -04:00
Lincoln Stein
fa6efac436 change names of convert and download caches and add migration script 2024-04-14 16:10:24 -04:00
Lincoln Stein
3ead827d61 port dw_openpose, depth_anything, and lama processors to new model download scheme 2024-04-14 16:10:24 -04:00
Lincoln Stein
c140d3b1df add invocation_context.load_ckpt_from_url() method 2024-04-14 16:10:24 -04:00
Lincoln Stein
34438ce1af add simplified model manager install API to InvocationContext 2024-04-14 16:10:24 -04:00
Lincoln Stein
3ddd7ced49 change names of convert and download caches and add migration script 2024-04-14 15:57:33 -04:00
Lincoln Stein
41b909cbe3 port dw_openpose, depth_anything, and lama processors to new model download scheme 2024-04-14 15:57:03 -04:00
Lincoln Stein
3a26c7bb9e fix merge conflicts 2024-04-12 00:58:11 -04:00
Lincoln Stein
df5ebdbc4f add invocation_context.load_ckpt_from_url() method 2024-04-12 00:55:21 -04:00
Lincoln Stein
af1b57a01f add simplified model manager install API to InvocationContext 2024-04-11 21:46:00 -04:00
Lincoln Stein
9cc1f20ad5 add simplified model manager install API to InvocationContext 2024-04-03 23:26:48 -04:00
314 changed files with 13946 additions and 7872 deletions

View File

@@ -18,6 +18,7 @@ help:
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
@@ -70,3 +71,6 @@ installer-zip:
tag-release:
cd installer && ./tag_release.sh
# Generate the OpenAPI Schema for the app
openapi:
python scripts/generate_openapi_schema.py

View File

@@ -64,7 +64,7 @@ GPU_DRIVER=nvidia
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
## Even Moar Customizing!
## Even More Customizing!
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.

View File

@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of
the progress of the download.
The only job type currently implemented is `DownloadJob`, a pydantic object with the
Two job types are defined. `DownloadJob` and
`MultiFileDownloadJob`. The former is a pydantic object with the
following fields:
| **Field** | **Type** | **Default** | **Description** |
@@ -138,7 +139,7 @@ following fields:
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True.
The `MultiFileDownloadJob` is used for diffusers model downloads,
which contain multiple files and directories under a common root:
| **Field** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| _Fields passed in at job creation time_ |
| `download_parts` | Set[DownloadJob]| | Component download jobs |
| `dest` | Path | | Where to download to |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
| _Fields updated over the course of the download task_
| `status` | DownloadJobStatus| | Status code |
| `download_path` | Path | | Path to the root of the downloaded files |
| `bytes` | int | 0 | Bytes downloaded so far |
| `total_bytes` | int | 0 | Total size of the file at the remote site |
| `error_type` | str | | String version of the exception that caused an error during download |
| `error` | str | | String version of the traceback associated with an error |
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
Note that the MultiFileDownloadJob does not support the `priority`,
`job_started`, `job_ended` or `content_type` attributes. You can get
these from the individual download jobs in `download_parts`.
### Callbacks
Download jobs can be associated with a series of callbacks, each with
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`.
#### job = queue.download(source, dest, priority, access_token)
#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
Create a new download job and put it on the queue, returning the
DownloadJob object.
#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
This is similar to download(), but instead of taking a single source,
it accepts a `parts` argument consisting of a list of
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
where the URL is the location of the remote file, and the Path is the
destination.
`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
consists of a url/path pair. Note that the path *must* be relative.
The method returns a `MultiFileDownloadJob`.
```
from invokeai.backend.model_manager.metadata import RemoteModelFile
remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
path='my_model/textencoder/pytorch_model.safetensors'
)
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
path='my_model/vae/diffusers_model.safetensors'
)
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
dest='/tmp/downloads',
on_progress=TqdmProgress().update)
queue.wait_for_job(job)
print(f"The files were downloaded to {job.download_path}")
```
#### jobs = queue.list_jobs()
Return a list of all active and inactive `DownloadJob`s.

View File

@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
following initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config import get_config
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)
db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService()
queue.start()
installer = ModelInstallService(app_config=config,
installer = ModelInstallService(app_config=config,
record_store=record_store,
download_queue=queue
)
download_queue=queue
)
installer.start()
```
@@ -1367,12 +1366,20 @@ the in-memory loaded model:
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
Because the loader can return multiple model types, it is typed to
return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
### get_model_by_key(key, [submodel]) -> LoadedModel
The `get_model_by_key()` method will retrieve the model using its
unique database key. For example:
loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
`get_model_by_key()` may raise any of the following exceptions:
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
### Using the Loaded Model in Inference
`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
@@ -1380,17 +1387,33 @@ in the execution device for the duration of the context, and returns
the model. Use it like this:
```
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with loaded_model as vae:
image = vae.decode(latents)[0]
```
`get_model_by_key()` may raise any of the following exceptions:
The object returned by the LoadedModel context manager is an
`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
In addition, you may call `LoadedModel.model_on_device()`, a context
manager that returns a tuple of the model's state dict in CPU and the
model itself in VRAM. It is used to optimize the LoRA patching and
unpatching process:
```
loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
image = vae.decode(latents)[0]
```
Since not all models have state dicts, the `state_dict` return value
can be None.
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
### Emitting model loading events
When the `context` argument is passed to `load_model_*()`, it will
@@ -1578,3 +1601,59 @@ This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`.
## Invocation Context Model Manager API
Within invocations, the following methods are available from the
`InvocationContext` object:
### context.download_and_cache_model(source) -> Path
This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can
be a direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
### context.load_local_model(model_path, [loader]) -> LoadedModel
This method loads a local model from the indicated path, returning a
`LoadedModel`. The optional loader is a Callable that accepts a Path
to the object, and returns an `AnyModel` object. If no loader is
provided, then the method will use `torch.load()` for a .ckpt or .bin
checkpoint file, `safetensors.torch.load_file()` for a safetensors
checkpoint file, or `cls.from_pretrained()` for a directory that looks
like a diffusers directory.
### context.load_remote_model(source, [loader]) -> LoadedModel
This method accepts a `source` of a remote model, downloads and caches
it locally, loads it, and returns a `LoadedModel`. The source can be a
direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
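For illustration, a short usage sketch of the three methods above from inside an invocation's `invoke()`; the checkpoint URL is a placeholder and the repo_id is the example value used above.
```
from pathlib import Path

def invoke(self, context) -> None:
    # Download (or reuse the cached copy of) a remote checkpoint; a local Path is returned.
    ckpt_path: Path = context.download_and_cache_model(
        "https://example.com/checkpoints/my_model.safetensors"
    )

    # Load it. With no loader given, a .safetensors file is read with
    # safetensors.torch.load_file(), as described above.
    with context.load_local_model(ckpt_path) as model:
        ...  # run inference while the model is locked on the execution device

    # Or download and load in a single call, here the fp16 VAE subfolder of a repo.
    with context.load_remote_model("stabilityai/stable-diffusion-v4:fp16:vae") as vae:
        ...
```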

View File

@@ -165,7 +165,7 @@ Additionally, each section can be expanded with the "Show Advanced" button in o
There are several ways to install IP-Adapter models with an existing InvokeAI installation:
1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models.
2. Through the Model Manager UI with models from the *Tools* section of [www.models.invoke.ai](https://www.models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
2. Through the Model Manager UI with models from the *Tools* section of [models.invoke.ai](https://models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
3. **Advanced -- Not recommended** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders should be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of the relevant base model folder of the Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder.
#### Using IP-Adapter

View File

@@ -154,6 +154,18 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
Check the [configuration docs] for more detail about the settings and how to specify them.
## `ModuleNotFoundError: No module named 'controlnet_aux'`
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
- Run the Invoke launcher
- Choose the developer console option
- Run this command: `pip cache remove controlnet_aux`
- Close the terminal window
- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
## Out of Memory Issues
The models are large, VRAM is expensive, and you may find yourself

View File

@@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s
4. The VAE decodes the final latent image from latent space into image space.
Image-to-image is a similar process, with only step 1 being different:
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps being performed. The process is then the same as steps 2-4 in the text-to-image process.
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).
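As a rough numeric illustration of the Denoising Strength relationship above (an approximation of the usual image-to-image timestep selection, not this repo's exact scheduler code):
```
def img2img_step_count(num_inference_steps: int, strength: float) -> int:
    # strength 0.0 -> no denoising steps run (image unchanged);
    # strength 1.0 -> the full schedule runs (image fully replaced by noise first).
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start

print(img2img_step_count(30, 0.0))   # 0
print(img2img_step_count(30, 0.75))  # 22
print(img2img_step_count(30, 1.0))   # 30
```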

View File

@@ -10,8 +10,7 @@ set INVOKEAI_ROOT=.
echo Desired action:
echo 1. Generate images with the browser-based interface
echo 2. Open the developer console
echo 3. Run the InvokeAI image database maintenance script
echo 4. Command-line help
echo 3. Command-line help
echo Q - Quit
echo.
echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.
@@ -34,9 +33,6 @@ IF /I "%choice%" == "1" (
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k
) ELSE IF /I "%choice%" == "3" (
echo Running the db maintenance script...
python .venv\Scripts\invokeai-db-maintenance.exe
) ELSE IF /I "%choice%" == "4" (
echo Displaying command line help...
python .venv\Scripts\invokeai-web.exe --help %*
pause

View File

@@ -47,11 +47,6 @@ do_choice() {
bash --init-file "$file_name"
;;
3)
clear
printf "Running the db maintenance script\n"
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
;;
4)
clear
printf "Command-line help\n"
invokeai-web --help
@@ -71,8 +66,7 @@ do_line_input() {
printf "What would you like to do?\n"
printf "1: Generate images using the browser-based interface\n"
printf "2: Open the developer console\n"
printf "3: Run the InvokeAI image database maintenance script\n"
printf "4: Command-line help\n"
printf "3: Command-line help\n"
printf "Q: Quit\n\n"
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n"
read -p "Please enter 1-4, Q: [1] " yn

View File

@@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.events.events_fastapievents import FastAPIEventService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
@@ -29,11 +30,10 @@ from ..services.model_images.model_images_default import ModelImageFileStorageDi
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
# TODO: is there a better way to achieve this?
@@ -93,7 +93,7 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
@@ -103,7 +103,7 @@ class ApiDependencies:
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)

View File

@@ -1,52 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events.events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(
event.get("event_name"),
payload=event.get("payload"),
middleware_id=self.event_handler_id,
)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@@ -69,7 +69,7 @@ async def upload_image(
if isinstance(metadata_raw, str):
_metadata = metadata_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image")
pass
# attempt to parse workflow from image
@@ -77,7 +77,7 @@ async def upload_image(
if isinstance(workflow_raw, str):
_workflow = workflow_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image")
ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image")
pass
# attempt to extract graph from image
@@ -85,7 +85,7 @@ async def upload_image(
if isinstance(graph_raw, str):
_graph = graph_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image")
ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image")
pass
try:

View File

@@ -9,7 +9,7 @@ from copy import deepcopy
from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
@@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
@@ -502,6 +502,133 @@ async def install_model(
return result
@model_manager_router.get(
"/install/huggingface",
operation_id="install_hugging_face_model",
responses={
201: {"description": "The model is being installed"},
400: {"description": "Bad request"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_class=HTMLResponse,
)
async def install_hugging_face_model(
source: str = Query(description="HuggingFace repo_id to install"),
) -> HTMLResponse:
"""Install a Hugging Face model using a string identifier."""
def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
if message:
message = f"<p>{message}</p>"
title_class = "error" if is_error else "success"
return f"""
<html>
<head>
<title>{title}</title>
<style>
body {{
text-align: center;
background-color: hsl(220 12% 10% / 1);
font-family: Helvetica, sans-serif;
color: hsl(220 12% 86% / 1);
}}
.repo-id {{
color: hsl(220 12% 68% / 1);
}}
.error {{
color: hsl(0 42% 68% / 1)
}}
.message-box {{
display: inline-block;
border-radius: 5px;
background-color: hsl(220 12% 20% / 1);
padding-inline-end: 30px;
padding: 20px;
padding-inline-start: 30px;
padding-inline-end: 30px;
}}
.container {{
display: flex;
width: 100%;
height: 100%;
align-items: center;
justify-content: center;
}}
a {{
color: inherit
}}
a:visited {{
color: inherit
}}
a:active {{
color: inherit
}}
</style>
</head>
<body style="background-color: hsl(220 12% 10% / 1);">
<div class="container">
<div class="message-box">
<h2 class="{title_class}">{heading}</h2>
{message}
<p class="repo-id">Repo ID: {repo_id}</p>
</div>
</div>
</body>
</html>
"""
try:
metadata = HuggingFaceMetadataFetch().from_id(source)
assert isinstance(metadata, ModelMetadataWithFiles)
except UnknownMetadataException:
title = "Unable to Install Model"
heading = "No HuggingFace repository found with that repo ID."
message = "Ensure the repo ID is correct and try again."
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
if metadata.is_diffusers:
installer.heuristic_import(
source=source,
inplace=False,
)
elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
installer.heuristic_import(
source=str(metadata.ckpt_urls[0]),
inplace=False,
)
else:
title = "Unable to Install Model"
heading = "This HuggingFace repo has multiple models."
message = "Please use the Model Manager to install this model."
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)
title = "Model Install Started"
heading = "Your HuggingFace model is installing now."
message = "You can close this tab and check the Model Manager for installation progress."
return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
except Exception as e:
logger.error(str(e))
title = "Unable to Install Model"
heading = "There was a problem installing this model."
message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.'
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
@model_manager_router.get(
"/install",
operation_id="list_model_installs",

View File

@@ -203,6 +203,7 @@ async def get_batch_status(
responses={
200: {"model": SessionQueueItem},
},
response_model_exclude_none=True,
)
async def get_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),

View File

@@ -1,66 +1,125 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from pydantic import BaseModel
from socketio import ASGIApp, AsyncServer
from ..services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadEventBase,
DownloadProgressEvent,
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io queue room.
This is a pydantic model to ensure the data is in the correct format."""
queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io bulk downloads room.
This is a pydantic model to ensure the data is in the correct format."""
bulk_download_id: str
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
}
MODEL_EVENTS = {
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
ModelInstallErrorEvent,
}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
class SocketIO:
__sio: AsyncServer
__app: ASGIApp
_sub_queue = "subscribe_queue"
_unsub_queue = "unsubscribe_queue"
__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
_sub_bulk_download = "subscribe_bulk_download"
_unsub_bulk_download = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app)
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
app.mount("/ws", self._app)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
register_events(QUEUE_EVENTS, self._handle_queue_event)
register_events(MODEL_EVENTS, self._handle_model_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["queue_id"],
)
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)

View File

@@ -3,9 +3,7 @@ import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
import torch
import uvicorn
@@ -13,11 +11,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
@@ -25,9 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
@@ -44,11 +39,6 @@ from .api.routers import (
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config()
@@ -118,93 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": {},
"required": [],
}
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():
# if model_config_name in openapi_schema["components"]["schemas"]:
# # print(f"Config with name {name} already defined")
# continue
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
app.openapi = get_openapi_func(app)
@app.get("/docs", include_in_schema=False)

View File

@@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel):
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
@@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel):
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "output"
schema["required"].extend(["type"])
@classmethod
@@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def get_type(cls) -> str:
@@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel):
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationsUnion)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
@@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
@@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
schema["required"].extend(["type", "id"])
@abstractmethod
@@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel):
protected_namespaces=(),
validate_assignment=True,
json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=True,
json_schema_serialization_defaults_required=False,
coerce_numbers_to_str=True,
)
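The two hunks above replace an eager rebuild of the discriminated union with a cached TypeAdapter guarded by a dirty flag. A minimal, self-contained sketch of that pattern (illustrative names, not InvokeAI code), assuming each registered model carries a `type` Literal used as the discriminator:

from typing import Annotated, Any, ClassVar, Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import TypeAliasType


class OutputA(BaseModel):
    type: Literal["output_a"] = "output_a"


class OutputB(BaseModel):
    type: Literal["output_b"] = "output_b"


class Registry:
    _classes: ClassVar[set[type[BaseModel]]] = set()
    _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
    _typeadapter_needs_update: ClassVar[bool] = False

    @classmethod
    def register(cls, model: type[BaseModel]) -> None:
        cls._classes.add(model)
        cls._typeadapter_needs_update = True  # invalidate the cached adapter

    @classmethod
    def get_typeadapter(cls) -> TypeAdapter[Any]:
        # Rebuild only if a class was registered since the last build.
        if cls._typeadapter is None or cls._typeadapter_needs_update:
            AnyRegistered = TypeAliasType(
                "AnyRegistered",
                Annotated[Union[tuple(cls._classes)], Field(discriminator="type")],
            )
            cls._typeadapter = TypeAdapter(AnyRegistered)
            cls._typeadapter_needs_update = False
        return cls._typeadapter


Registry.register(OutputA)
Registry.register(OutputB)
parsed = Registry.get_typeadapter().validate_python({"type": "output_b"})  # -> OutputB()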

View File

@@ -0,0 +1,98 @@
from typing import Any, Union
import numpy as np
import numpy.typing as npt
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
@invocation(
"lblend",
title="Blend Latents",
tags=["latents", "blend"],
category="latents",
version="1.0.3",
)
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
latents_a: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
latents_b: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.tensors.load(self.latents_a.latents_name)
latents_b = context.tensors.load(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
device = TorchDevice.choose_torch_device()
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
v0: Union[torch.Tensor, npt.NDArray[Any]],
v1: Union[torch.Tensor, npt.NDArray[Any]],
DOT_THRESHOLD: float = 0.9995,
) -> Union[torch.Tensor, npt.NDArray[Any]]:
"""
Spherical linear interpolation
Args:
t (float/np.ndarray): Float value between 0.0 and 1.0
v0 (np.ndarray): Starting vector
v1 (np.ndarray): Final vector
DOT_THRESHOLD (float): Threshold for considering the two vectors as
collinear. Not recommended to alter this.
Returns:
v2 (np.ndarray): Interpolation vector between v0 and v1
"""
inputs_are_torch = False
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
v0 = v0.detach().cpu().numpy()
if not isinstance(v1, np.ndarray):
inputs_are_torch = True
v1 = v1.detach().cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
return v2_torch
else:
assert isinstance(v2, np.ndarray)
return v2
# blend
bl = slerp(self.alpha, latents_a, latents_b)
assert isinstance(bl, torch.Tensor)
blended_latents: torch.Tensor = bl # for type checking convenience
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
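As an aside on why the node blends with spherical rather than linear interpolation, a small numeric check (not part of the node) shows that slerp preserves the norm of the blend where a plain lerp would shrink it:

import numpy as np

v0 = np.array([1.0, 0.0])
v1 = np.array([0.0, 1.0])
t = 0.5

lerped = (1 - t) * v0 + t * v1                    # norm ~= 0.707, magnitude is lost
dot = np.clip(np.dot(v0, v1), -1.0, 1.0)
theta = np.arccos(dot)
slerped = (np.sin((1 - t) * theta) * v0 + np.sin(t * theta) * v1) / np.sin(theta)
# np.linalg.norm(slerped) == 1.0: the blend stays on the unit sphere, which better
# preserves the roughly Gaussian statistics the diffusion model expects of its latents.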

View File

@@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
@@ -84,19 +80,25 @@ class CompelInvocation(BaseInvocation):
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
model_state_dict=model_state_dict,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
):
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel(
tokenizer=tokenizer,
tokenizer=patched_tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@@ -106,7 +108,7 @@ class CompelInvocation(BaseInvocation):
conjunction = Compel.parse_prompt_string(self.prompt)
if context.config.get().log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer)
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -136,11 +138,7 @@ class SDXLPromptInvocationBase:
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty
if prompt == "" and zero_on_empty:
@@ -177,20 +175,28 @@ class SDXLPromptInvocationBase:
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (state_dict, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
model_state_dict=state_dict,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)
text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel(
tokenizer=tokenizer,
tokenizer=patched_tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@@ -203,7 +209,7 @@ class SDXLPromptInvocationBase:
if context.config.get().log_tokenization:
# TODO: better logging for and syntax
log_tokenization_for_conjunction(conjunction, tokenizer)
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
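The ordering comments above rely on standard context-manager semantics: managers listed later in the `with (...)` tuple are entered after, and exited before, the earlier ones, so the LoRA and CLIP-skip patches wrap the already-on-device text encoder and are unwound before it is moved back. A generic illustration (plain Python, no InvokeAI API):

from contextlib import contextmanager

@contextmanager
def patch(name: str):
    print(f"apply {name}")
    try:
        yield
    finally:
        print(f"undo {name}")

with (
    patch("move text_encoder to device"),
    patch("LoRA"),
    patch("CLIP skip"),
):
    print("run text encoder")
# apply move -> apply LoRA -> apply CLIP skip -> run -> undo CLIP skip -> undo LoRA -> undo move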

View File

@@ -1,6 +1,7 @@
from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8
"""
@@ -15,3 +16,5 @@ SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()

View File

@@ -2,6 +2,7 @@
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from pathlib import Path
from typing import Dict, List, Literal, Union
import cv2
@@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
@@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return context.images.get_pil(self.image.image_name, "RGB")
def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
raw_image = self.load_image(context)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(
image,
@@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(
image,
@@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(
image,
@@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(
image,
@@ -405,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
@@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(
image,
@@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
@@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img
def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8)
def run_processor(self, image: Image.Image) -> Image.Image:
np_img = np.array(image, dtype=np.uint8)
processed_np_image = self.tile_resample(
np_img,
# res=self.tile_size,
@@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
@@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
def run_processor(self, image: Image.Image):
def run_processor(self, image: Image.Image) -> Image.Image:
np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2]
@@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
)
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image):
depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size)
def run_processor(self, image: Image.Image) -> Image.Image:
def loader(model_path: Path):
return DepthAnythingDetector.load_model(
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
)
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
) as model:
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
@invocation(
@@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
draw_hands: bool = InputField(default=False)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image):
dw_openpose = DWOpenposeDetector()
def run_processor(self, image: Image.Image) -> Image.Image:
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
processed_image = dw_openpose(
image,
draw_face=self.draw_face,

View File

@@ -0,0 +1,80 @@
from typing import Optional
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import DenoiseMaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"create_denoise_mask",
title="Create Denoise Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=4,
)
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
# if shape is not None:
# mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
return mask_tensor
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
else:
image_tensor = None
mask = self.prep_mask_tensor(
context.images.get_pil(self.mask.image_name),
)
if image_tensor is not None:
vae_info = context.models.load(self.vae.vae)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.tensors.save(tensor=masked_latents)
else:
masked_latents_name = None
mask_name = context.tensors.save(tensor=mask)
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
gradient=False,
)
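A tiny illustration (assumed shapes, not part of the node) of the masking step above before the VAE encode:

import torch

image_tensor = torch.ones(1, 3, 8, 8)              # stand-in RGB image batch
img_mask = torch.rand(1, 1, 8, 8)                  # soft mask in [0, 1]
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# pixels where the mask is below 0.5 are zeroed; everything else passes through
# unchanged before being encoded into masked_latents by the VAE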

View File

@@ -0,0 +1,138 @@
from typing import Literal, Optional
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
ImageField,
Input,
InputField,
OutputField,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
image: Optional[ImageField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] Image",
ui_order=6,
)
unet: Optional[UNetField] = InputField(
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
default=None,
input=Input.Connection,
title="[OPTIONAL] UNet",
ui_order=5,
)
vae: Optional[VAEField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] VAE",
input=Input.Connection,
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
else:
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
masked_latents_name = None
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)
masked_latents_name = context.tensors.save(tensor=masked_latents)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
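A worked example (assumed values) of the gradient-mask arithmetic above, with minimum_denoise = 0.25 so threshold = 0.75, in "Gaussian Blur" mode:

import torch

blur_tensor = torch.tensor([0.5, 0.9, 1.0])        # blurred mask values in [0, 1]
blur_tensor = (blur_tensor - 0.5) * 2              # -> tensor([0.0, 0.8, 1.0])
threshold = 1 - 0.25                               # minimum_denoise = 0.25
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
# -> tensor([0.0000, 0.7500, 1.0000]); values strictly between threshold and 1 are clamped down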

View File

@@ -0,0 +1,61 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
# The Crop Latents node was copied from @skunkworxdark's implementation here:
# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
@invocation(
"crop_latents",
title="Crop Latents",
tags=["latents", "crop"],
category="latents",
version="1.0.2",
)
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails.
class CropLatentsCoreInvocation(BaseInvocation):
"""Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
divisible by the latent scale factor of 8.
"""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
x: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
y: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
width: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
height: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y // LATENT_SCALE_FACTOR
x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
cropped_latents = latents[..., y1:y2, x1:x2]
name = context.tensors.save(tensor=cropped_latents)
return LatentsOutput.build(latents_name=name, latents=cropped_latents)
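A worked example (assumed inputs) of the pixel-to-latent conversion performed above:

LATENT_SCALE_FACTOR = 8
x, y, width, height = 64, 32, 256, 128                                  # image-space crop
x1, y1 = x // LATENT_SCALE_FACTOR, y // LATENT_SCALE_FACTOR             # 8, 4
x2 = x1 + width // LATENT_SCALE_FACTOR                                  # 40
y2 = y1 + height // LATENT_SCALE_FACTOR                                 # 20
# cropped_latents = latents[..., y1:y2, x1:x2] -> latents[..., 4:20, 8:40]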

View File

@@ -0,0 +1,848 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torchvision
import torchvision.transforms as T
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.adapter import T2IAdapter
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
ConditioningField,
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import ModelIdentifierField, UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
T2IAdapterData,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
IPAdapterConditioningInfo,
IPAdapterData,
Range,
SDXLConditioningInfo,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelIdentifierField,
scheduler_name: str,
seed: int,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {
**scheduler_config,
**scheduler_extra_config, # FIXME
"_backup": scheduler_config,
}
# make dpmpp_sde reproducible (the seed can only be passed in the initializer)
if scheduler_class is DPMSolverSDEScheduler:
scheduler_config["noise_sampler_seed"] = seed
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, "uses_inpainting_model"):
scheduler.uses_inpainting_model = lambda: False
assert isinstance(scheduler, Scheduler)
return scheduler
@invocation(
"denoise_latents",
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.5.3",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
)
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
)
noise: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
ui_order=3,
)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
)
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
ui_order=2,
)
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
default=None,
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
description=FieldDescriptions.ip_adapter,
title="IP-Adapter",
default=None,
input=Input.Connection,
ui_order=6,
)
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
description=FieldDescriptions.t2i_adapter,
title="T2I-Adapter",
default=None,
input=Input.Connection,
ui_order=7,
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
ui_order=4,
)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
input=Input.Connection,
ui_order=8,
)
@field_validator("cfg_scale")
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def _get_text_embeddings_and_masks(
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
"""Get the text embeddings and masks from the input conditioning fields."""
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
text_embeddings_masks.append(mask)
return text_embeddings, text_embeddings_masks
@staticmethod
def _preprocess_regional_prompt_mask(
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
Returns:
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
mask = to_standard_float_mask(mask, out_dtype=dtype)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
return resized_mask
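A minimal illustration (assumed sizes, with a generic float conversion standing in for to_standard_float_mask) of the resize performed above:

import torch
import torchvision

mask = (torch.rand(1, 512, 512) > 0.5).to(torch.float32)   # image-space region mask
tf = torchvision.transforms.Resize(
    (64, 64), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
resized = tf(mask.unsqueeze(0))                            # shape (1, 1, 64, 64)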
@staticmethod
def _concat_regional_text_embeddings(
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
text_embedding = []
pooled_embedding = None
add_time_ids = None
cur_text_embedding_len = 0
processed_masks = []
embedding_ranges = []
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
mask = masks[prompt_idx]
if is_sdxl:
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
# global prompt information. In an ideal case, there should be exactly one global prompt without a
# mask, but we don't enforce this.
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
# pretty major breaking change to a popular node, so for now we use this hack.
if pooled_embedding is None or mask is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None or mask is None:
add_time_ids = text_embedding_info.add_time_ids
text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none:
embedding_ranges.append(
Range(
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
)
)
processed_masks.append(
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
mask, latent_height, latent_width, dtype=dtype
)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
regions = None
if not all_masks_are_none:
regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
)
if is_sdxl:
return (
SDXLConditioningInfo(embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids),
regions,
)
return BasicConditioningInfo(embeds=text_embedding), regions
@staticmethod
def get_conditioning_data(
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
) -> TextConditioningData:
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
cond_list = positive_conditioning_field
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = negative_conditioning_field
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
if isinstance(cfg_scale, list):
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=cfg_scale,
guidance_rescale_multiplier=cfg_rescale_multiplier,
)
return conditioning_data
@staticmethod
def create_pipeline(
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
@staticmethod
def prep_control_data(
context: InvocationContext,
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> list[ControlNetData] | None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
if len(control_list) == 0:
return None
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)
control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return controlnet_data
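A quick illustration (assumed latents shape) of the resize targets computed above:

LATENT_SCALE_FACTOR = 8
_, _, latent_height, latent_width = (1, 4, 64, 96)          # assumed latents shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR # 512 px
control_width_resize = latent_width * LATENT_SCALE_FACTOR   # 768 px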
def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
ip_adapters: List[IPAdapterField],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
)
image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
return image_prompts
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapters: List[IPAdapterField],
image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> Optional[List[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
ip_adapter_data_list = []
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
ip_adapter_data_list.append(
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
target_blocks=single_ip_adapter.target_blocks,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
mask=mask,
)
)
return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None
def run_t2i_adapters(
self,
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
return None
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
if isinstance(t2i_adapter, T2IAdapterField):
t2i_adapter = [t2i_adapter]
if len(t2i_adapter) == 0:
return None
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# max_unet_downscale is the maximum factor by which the UNet downscales the latent image internally.
if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
max_unet_downscale = 8
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4
else:
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
t2i_adapter_model: T2IAdapter
with t2i_adapter_loaded_model as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
# result will match the latent image's dimensions after max_unet_downscale is applied.
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
# T2I-Adapter model.
#
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
# of the same requirements (e.g. preserving binary masks during resize).
t2i_image = prepare_control_image(
image=image,
do_classifier_free_guidance=False,
width=t2i_input_width,
height=t2i_input_height,
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
adapter_state = t2i_adapter_model(t2i_image)
if do_classifier_free_guidance:
for idx, value in enumerate(adapter_state):
adapter_state[idx] = torch.cat([value] * 2, dim=0)
t2i_adapter_data.append(
T2IAdapterData(
adapter_state=adapter_state,
weight=t2i_adapter_field.weight,
begin_step_percent=t2i_adapter_field.begin_step_percent,
end_step_percent=t2i_adapter_field.end_step_percent,
)
)
return t2i_adapter_data
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@staticmethod
def init_scheduler(
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")
timesteps = scheduler.timesteps.to(device=device)
else:
scheduler.set_timesteps(steps, device=device)
timesteps = scheduler.timesteps
# skip higher-order timesteps (keep one timestep per scheduler step)
_timesteps = timesteps[:: scheduler.order]
# get start timestep index
t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
# get end timestep index
t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
# apply order to indexes
t_start_idx *= scheduler.order
t_end_idx *= scheduler.order
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step)
if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
#
# These Invoke-supported schedulers accept a generator as of 2024-06-04:
# - DDIMScheduler
# - DDPMScheduler
# - DPMSolverMultistepScheduler
# - EulerAncestralDiscreteScheduler
# - EulerDiscreteScheduler
# - KDPM2AncestralDiscreteScheduler
# - LCMScheduler
# - TCDScheduler
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})
return timesteps, init_timestep, scheduler_step_kwargs
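A worked example (assumed first-order scheduler with num_train_timesteps = 1000) of the start/end slicing above, for steps = 10, denoising_start = 0.3, denoising_end = 1.0, i.e. an image-to-image run that skips the first 30% of the schedule:

timesteps = [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]            # illustrative values
t_start_val = int(round(1000 * (1 - 0.3)))                               # 700
t_start_idx = len([t for t in timesteps if t >= t_start_val])            # 3
t_end_val = int(round(1000 * (1 - 1.0)))                                 # 0
t_end_idx = len([t for t in timesteps[t_start_idx:] if t >= t_end_val])  # 7
init_timestep = timesteps[t_start_idx : t_start_idx + 1]                 # [699]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]             # only the last 7 steps run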
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
if self.denoise_mask is None:
return None, None, False
mask = context.tensors.load(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
else:
masked_latents = torch.where(mask < 0.5, 0.0, latents)
return 1 - mask, masked_latents, self.denoise_mask.gradient
@staticmethod
def prepare_noise_and_latents(
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
"""Depending on the workflow, we expect different combinations of noise and latents to be provided. This
function handles preparing these values accordingly.
Expected workflows:
- Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
- Image-to-Image Denoising: `noise` and `latents` are both provided.
- Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
- Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
I haven't considered.
"""
noise = None
if noise_field is not None:
noise = context.tensors.load(noise_field.latents_name)
if latents_field is not None:
latents = context.tensors.load(latents_field.latents_name)
elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise ValueError("'latents' or 'noise' must be provided!")
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
if noise_field is not None and noise_field.seed is not None:
seed = noise_field.seed
elif latents_field is not None and latents_field.seed is not None:
seed = latents_field.seed
else:
seed = 0
return seed, noise, latents
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
t2i_adapter_data = self.run_t2i_adapters(
context,
self.t2i_adapter,
latents.shape,
do_classifier_free_guidance=True,
)
ip_adapters: List[IPAdapterField] = []
if self.ip_adapter is not None:
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if isinstance(self.ip_adapter, list):
ip_adapters = self.ip_adapter
else:
ip_adapters = [self.ip_adapter]
# If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
# a series of image conditioning embeddings. This is being done here rather than in the
# big model context below in order to use less VRAM on low-VRAM systems.
# The image prompts are then passed to prep_ip_adapter_data().
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
# get the unet's config so that we can pass the base to dispatch_progress()
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
if mask is not None:
mask = mask.to(device=unet.device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet, scheduler)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = self.prep_control_data(
context=context,
control_input=self.control,
latents_shape=latents.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapters=ip_adapters,
image_prompts=image_prompts,
exit_stack=exit_stack,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
result_latents = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
mask=mask,
masked_latents=masked_latents,
is_gradient_mask=gradient_mask,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
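The four workflows documented in prepare_noise_and_latents() collapse into a single (seed, noise, latents) tuple. As a rough illustration only (plain tensors and a hypothetical helper name, no InvocationContext), the resolution rules look like this:

from typing import Optional, Tuple

import torch


def resolve_noise_and_latents(
    noise: Optional[torch.Tensor],
    latents: Optional[torch.Tensor],
    noise_seed: Optional[int] = None,
    latents_seed: Optional[int] = None,
) -> Tuple[int, Optional[torch.Tensor], torch.Tensor]:
    """Illustrative sketch of the rules documented in prepare_noise_and_latents()."""
    # Text-to-image: only noise is provided, so denoising starts from zeros.
    if latents is None:
        if noise is None:
            raise ValueError("'latents' or 'noise' must be provided!")
        latents = torch.zeros_like(noise)
    # Image-to-image: both are provided and must agree on all but the batch dimension.
    if noise is not None and noise.shape[1:] != latents.shape[1:]:
        raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
    # Seed priority: the noise field, then the latents field, then 0.
    if noise_seed is not None:
        seed = noise_seed
    elif latents_seed is not None:
        seed = latents_seed
    else:
        seed = 0
    return seed, noise, latents


# Text-to-image example: only noise is available.
seed, noise, latents = resolve_noise_and_latents(torch.randn(1, 4, 64, 64), None, noise_seed=1234)
assert latents.shape == noise.shape and seed == 1234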


@@ -0,0 +1,65 @@
import math
from typing import Tuple
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import UNetField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@invocation_output("ideal_size_output")
class IdealSizeOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
width: int = OutputField(description="The ideal width of the image (in pixels)")
height: int = OutputField(description="The ideal height of the image (in pixels)")
@invocation(
"ideal_size",
title="Ideal Size",
tags=["latents", "math", "ideal_size"],
version="1.0.3",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
width: int = InputField(default=1024, description="Final image width")
height: int = InputField(default=576, description="Final image height")
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
multiplier: float = InputField(
default=1.0,
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
"initial generation artifacts if too large)",
)
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(self.unet.unet.key)
aspect = self.width / self.height
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images
if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect))
init_width = init_height * aspect
else:
init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect
scaled_width, scaled_height = self.trim_to_multiple_of(
math.floor(init_width),
math.floor(init_height),
)
return IdealSizeOutput(width=scaled_width, height=scaled_height)
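The size calculation above can be checked by hand. A worked example, assuming an SDXL base (native dimension 1024), the node defaults of width=1024 / height=576, and a latent scale factor of 8:

import math

LATENT_SCALE_FACTOR = 8  # assumed: latent space is 8x smaller than pixel space

width, height, multiplier = 1024, 576, 1.0
aspect = width / height                      # ~1.778
dimension = 1024 * multiplier                # SDXL native dimension
min_dimension = math.floor(dimension * 0.5)  # 512
model_area = dimension * dimension           # 1,048,576 "native" pixels

# aspect > 1.0, so solve for the height first and derive the width from the aspect ratio.
init_height = max(min_dimension, math.sqrt(model_area / aspect))  # 768.0
init_width = init_height * aspect                                 # ~1365.3

# Trim both dimensions down to a multiple of the latent scale factor.
ideal = tuple(x - x % LATENT_SCALE_FACTOR for x in (math.floor(init_width), math.floor(init_height)))
print(ideal)  # (1360, 768)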


@@ -0,0 +1,125 @@
from functools import singledispatchmethod
import einops
import torch
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"i2l",
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.0.2",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
image: ImageField = InputField(
description="The image to encode",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, torch.nn.Module)
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(orig_dtype)
vae.decoder.conv_in.to(orig_dtype)
vae.decoder.mid_block.to(orig_dtype)
# else:
# latents = latents.float()
else:
vae.to(dtype=torch.float16)
# latents = latents.half()
if tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents: torch.Tensor = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
latents: torch.FloatTensor = vae.encode(image_tensor).latents
return latents
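Outside of the invocation framework, the core of the encode path above is diffusers' AutoencoderKL plus the scaling factor from its config. A minimal sketch, assuming a standard Stable Diffusion VAE (the checkpoint id below is only an example) and an RGB input image:

import numpy as np
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from PIL import Image

# Example VAE only; any AutoencoderKL checkpoint with a scaling_factor in its config behaves the same way.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

image = Image.open("input.png").convert("RGB")
image_tensor = torch.from_numpy(np.array(image, dtype=np.float32) / 255.0)
image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # (1, 3, H, W)
image_tensor = image_tensor * 2.0 - 1.0                    # the VAE expects inputs in [-1, 1]

with torch.inference_mode():
    latent_dist = vae.encode(image_tensor).latent_dist
    # As the FIXME above notes, .sample() uses torch.randn and is not reproducible as-is.
    latents = latent_dist.sample() * vae.config.scaling_factor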


@@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infill the image with the specified method"""
pass
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
def load_image(self) -> tuple[Image.Image, bool]:
"""Process the image to have an alpha channel before being infilled"""
image = context.images.get_pil(self.image.image_name)
image = self._context.images.get_pil(self.image.image_name)
has_alpha = image.mode == "RGBA"
return image, has_alpha
def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
# Retrieve and process image to be infilled
input_image, has_alpha = self.load_image(context)
input_image, has_alpha = self.load_image()
# If the input image has no alpha channel, return it
if has_alpha is False:
@@ -133,8 +134,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image):
lama = LaMA()
return lama(image)
with self._context.models.load_remote_model(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=LaMA.load_jit_model,
) as model:
lama = LaMA(model)
return lama(image)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")

File diff suppressed because it is too large.


@@ -0,0 +1,127 @@
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.util.devices import TorchDevice
@invocation(
"l2i",
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.2.2",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_decode(
context: InvocationContext,
vae_info: LoadedModel,
seamless_axes: list[str],
latents: torch.Tensor,
use_fp32: bool,
use_tiling: bool,
) -> Image.Image:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if use_fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()
if use_tiling or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode():
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
TorchDevice.empty_cache()
return image
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=latents,
use_fp32=self.fp32,
use_tiling=self.tiled,
)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
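The decode path mirrors the encode step: divide by the scaling factor, decode, and map from [-1, 1] back to [0, 1]. A standalone sketch under the same assumptions as the encoding example (example VAE checkpoint, random latents standing in for real ones):

import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # example checkpoint only
latents = torch.randn(1, 4, 64, 64)  # stand-in for latents produced by a denoising node

with torch.inference_mode():
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)  # denormalize, as in vae_decode() above
    np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    pil_image = VaeImageProcessor.numpy_to_pil(np_image)[0]

pil_image.save("decoded.png")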


@@ -0,0 +1,103 @@
from typing import Literal
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
)
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@invocation(
"lresize",
title="Resize Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
width: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
height: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.height,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@invocation(
"lscale",
title="Scale Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
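Both nodes above are thin wrappers around torch.nn.functional.interpolate applied in latent space. A quick standalone example of the resize, assuming a latent scale factor of 8:

import torch
import torch.nn.functional as F

LATENT_SCALE_FACTOR = 8  # assumed value
latents = torch.randn(1, 4, 64, 64)          # a 512x512 image in latent space

target_width, target_height = 768, 512       # pixel-space target, floor-divided by 8 below
resized = F.interpolate(
    latents,
    size=(target_height // LATENT_SCALE_FACTOR, target_width // LATENT_SCALE_FACTOR),
    mode="bilinear",
    antialias=False,  # antialias is only valid for bilinear/bicubic modes
)
print(resized.shape)  # torch.Size([1, 4, 64, 96])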


@@ -0,0 +1,34 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
OutputField,
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@invocation(
"scheduler",
title="Scheduler",
tags=["scheduler"],
category="latents",
version="1.0.0",
)
class SchedulerInvocation(BaseInvocation):
"""Selects a scheduler."""
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
def invoke(self, context: InvocationContext) -> SchedulerOutput:
return SchedulerOutput(scheduler=self.scheduler)


@@ -0,0 +1,268 @@
import copy
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
from invokeai.backend.tiles.utils import TBLR
from invokeai.backend.util.devices import TorchDevice
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
"""Crop a ControlNetData object to a region."""
# Create a shallow copy of the control_data object.
control_data_copy = copy.copy(control_data)
# The ControlNet reference image is the only attribute that needs to be cropped.
control_data_copy.image_tensor = control_data.image_tensor[
:,
:,
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
]
return control_data_copy
@invocation(
"tiled_multi_diffusion_denoise_latents",
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
# TODO(ryand): Reset to 1.0.0 right before release.
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.
This node handles automatically tiling the input image. Future iterations of
this node should allow the user to specify custom regions with different parameters for each region to harness the
full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
"""
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
# TODO(ryand): Add multiple-of validation.
# TODO(ryand): Smaller defaults might make more sense.
tile_height: int = InputField(default=112, gt=0, description="Height of the tiles in latent space.")
tile_width: int = InputField(default=112, gt=0, description="Width of the tiles in latent space.")
tile_min_overlap: int = InputField(
default=16,
gt=0,
description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
"this to evenly cover the entire image.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
# TODO(ryand): The default here should probably be 0.0.
denoising_start: float = InputField(
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
control: ControlField | list[ControlField] | None = InputField(
default=None,
input=Input.Connection,
)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than or equal to 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than or equal to 1")
return v
@staticmethod
def create_pipeline(
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
) -> MultiDiffusionPipeline:
# TODO(ryand): Get rid of this FakeVae hack.
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return MultiDiffusionPipeline(
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.tile_min_overlap,
)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=self.tile_height,
latent_width=self.tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# Split the controlnet_data into tiles.
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
controlnet_data_tiles: list[list[ControlNetData]] = []
for tile in tiles:
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
controlnet_data_tiles.append(tile_controlnet_data)
# Prepare the MultiDiffusionRegionConditioning list.
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
multi_diffusion_conditioning.append(
MultiDiffusionRegionConditioning(
region=tile.coords,
text_conditioning_data=conditioning_data,
control_data=tile_controlnet_data,
)
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,
timesteps=timesteps,
init_timestep=init_timestep,
# TODO(ryand): Add proper callback.
callback=lambda x: None,
)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
result_latents = result_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
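The tiling behaviour relied on above (fixed-size tiles whose overlap is at least tile_min_overlap, spread to cover the whole latent image) can be illustrated with a small standalone calculator. This is not the calc_tiles_min_overlap() implementation, only a sketch of the idea:

import math
from typing import List, Tuple


def tile_starts(image_size: int, tile_size: int, min_overlap: int) -> List[int]:
    """Return evenly spaced tile origins so adjacent tiles overlap by at least min_overlap."""
    if tile_size >= image_size:
        return [0]
    stride = tile_size - min_overlap
    num_tiles = math.ceil((image_size - tile_size) / stride) + 1
    # Spread the tiles evenly so the last tile ends exactly at the image edge.
    return [round(i * (image_size - tile_size) / (num_tiles - 1)) for i in range(num_tiles)]


def tiles(h: int, w: int, tile_h: int, tile_w: int, min_overlap: int) -> List[Tuple[int, int, int, int]]:
    return [
        (top, top + tile_h, left, left + tile_w)
        for top in tile_starts(h, tile_h, min_overlap)
        for left in tile_starts(w, tile_w, min_overlap)
    ]


# e.g. a 192x160 latent image with the node defaults (112x112 tiles, overlap >= 16):
for t in tiles(192, 160, 112, 112, 16):
    print(t)  # (top, bottom, left, right) in latent coordinates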


@@ -0,0 +1,380 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import numpy as np
import numpy.typing as npt
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
ImageField,
Input,
InputField,
UIType,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@invocation(
"tiled_stable_diffusion_refine",
title="Tiled Stable Diffusion Refine",
tags=["upscale", "denoise"],
category="latents",
version="1.0.0",
)
class TiledStableDiffusionRefineInvocation(BaseInvocation):
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
"""
image: ImageField = InputField(description="Image to be refined.")
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
# TODO(ryand): Add multiple-of validation.
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
tile_overlap: int = InputField(
default=16,
gt=0,
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
vae_fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
)
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we awkwardly
# don't want to use the image field. Figure out how best to handle this.
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
# etc.
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float = InputField(default=0.6)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than or equal to 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than or equal to 1")
return v
@staticmethod
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
"""
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
if coord % LATENT_SCALE_FACTOR != 0:
raise ValueError(
f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
)
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
right = image_tile.coords.right // LATENT_SCALE_FACTOR
return latents[..., top:bottom, left:right]
def run_controlnet(
self,
image: Image.Image,
controlnet_model: ControlNetModel,
weight: float,
do_classifier_free_guidance: bool,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
) -> ControlNetData:
control_image = prepare_control_image(
image=image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
device=device,
dtype=dtype,
control_mode=control_mode,
resize_mode=resize_mode,
)
return ControlNetData(
model=controlnet_model,
image_tensor=control_image,
weight=weight,
begin_step_percent=0.0,
end_step_percent=1.0,
control_mode=control_mode,
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
# ControlNetData in case needed in the future.
resize_mode=resize_mode,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Expose the seed parameter.
seed = 0
# Load the input image.
input_image = context.images.get_pil(self.image.image_name)
# Calculate the tile locations to cover the image.
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
# facilitates conversions between image space and latent space.
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
tiles = calc_tiles_with_overlap(
image_height=input_image.height,
image_width=input_image.width,
tile_height=self.tile_height,
tile_width=self.tile_width,
overlap=self.tile_overlap,
)
# Convert the input image to a torch.Tensor.
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
# Validate our assumptions about the shape of input_image_torch.
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
assert input_image_torch.shape[:2] == (1, 3)
# Split the input image into tiles in torch.Tensor format.
image_tiles_torch: list[torch.Tensor] = []
for tile in tiles:
image_tile = input_image_torch[
:,
:,
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
]
image_tiles_torch.append(image_tile)
# Split the input image into tiles in numpy format.
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
# with torch.Tensor tiles.
input_image_np = np.array(input_image)
image_tiles_np: list[npt.NDArray[np.uint8]] = []
for tile in tiles:
image_tile_np = input_image_np[
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
:,
]
image_tiles_np.append(image_tile_np)
# VAE-encode each image tile independently.
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
# about for decoding?
vae_info = context.models.load(self.vae.vae)
latent_tiles: list[torch.Tensor] = []
for image_tile_torch in image_tiles_torch:
latent_tiles.append(
ImageToLatentsInvocation.vae_encode(
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
)
)
# Generate noise with dimensions corresponding to the full image in latent space.
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
# noise.
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
global_noise = get_noise(
width=input_image_torch.shape[3],
height=input_image_torch.shape[2],
device=TorchDevice.choose_torch_device(),
seed=seed,
downsampling_factor=LATENT_SCALE_FACTOR,
use_cpu=True,
)
# Crop the global noise into tiles.
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
refined_latent_tiles: list[torch.Tensor] = []
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
# Assume that all tiles have the same shape.
_, _, latent_height, latent_width = latent_tiles[0].shape
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
# Load the ControlNet model.
# TODO(ryand): Support multiple ControlNet models.
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
assert isinstance(controlnet_model, ControlNetModel)
# Denoise (i.e. "refine") each tile independently.
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
assert latent_tile.shape == noise_tile.shape
# Prepare a PIL Image for ControlNet processing.
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
image_tile_pil = Image.fromarray(image_tile_np)
# Run the ControlNet on the image tile.
height, width, _ = image_tile_np.shape
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
# validate this assumption here.
assert height % LATENT_SCALE_FACTOR == 0
assert width % LATENT_SCALE_FACTOR == 0
controlnet_data = self.run_controlnet(
image=image_tile_pil,
controlnet_model=controlnet_model,
weight=self.control_weight,
do_classifier_free_guidance=True,
width=width,
height=height,
device=controlnet_model.device,
dtype=controlnet_model.dtype,
control_mode="balanced",
resize_mode="just_resize_simple",
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
refined_latent_tile = pipeline.latents_from_embeddings(
latents=latent_tile,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise_tile,
seed=seed,
mask=None,
masked_latents=None,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=[controlnet_data],
ip_adapter_data=None,
t2i_adapter_data=None,
callback=lambda x: None,
)
refined_latent_tiles.append(refined_latent_tile)
# VAE-decode each refined latent tile independently.
refined_image_tiles: list[Image.Image] = []
for refined_latent_tile in refined_latent_tiles:
refined_image_tile = LatentsToImageInvocation.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=refined_latent_tile,
use_fp32=self.vae_fp32,
use_tiling=False,
)
refined_image_tiles.append(refined_image_tile)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
# Merge the refined image tiles back into a single image.
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
merge_tiles_with_linear_blending(
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
)
# Save the refined image and return its reference.
merged_image_pil = Image.fromarray(merged_image_np)
image_dto = context.images.save(image=merged_image_pil)
return ImageOutput.build(image_dto)
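A quick check of the image-space to latent-space mapping used by crop_latents_to_tile(): because the tile coordinates are required to be multiples of the latent scale factor (assumed to be 8 here), they map exactly onto latent indices:

import torch

LATENT_SCALE_FACTOR = 8  # assumed value

latents = torch.randn(1, 4, 128, 128)           # latents for a 1024x1024 pixel image
top, left, bottom, right = 256, 512, 768, 1024  # pixel-space tile, all multiples of 8

crop = latents[
    ...,
    top // LATENT_SCALE_FACTOR : bottom // LATENT_SCALE_FACTOR,
    left // LATENT_SCALE_FACTOR : right // LATENT_SCALE_FACTOR,
]
print(crop.shape)  # torch.Size([1, 4, 64, 64])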


@@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal
import cv2
@@ -10,10 +9,8 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
rrdbnet_model = None
netscale = None
esrgan_model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
@@ -95,28 +91,25 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg)
raise ValueError(msg)
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
# Downloads the ESRGAN model if it doesn't already exist
download_with_progress_bar(
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
loadnet = context.models.load_remote_model(
source=ESRGAN_MODEL_URLS[self.model_name],
)
upscaler = RealESRGAN(
scale=netscale,
model_path=esrgan_model_path,
model=rrdbnet_model,
half=False,
tile=self.tile_size,
)
with loadnet as loadnet_model:
upscaler = RealESRGAN(
scale=netscale,
loadnet=loadnet_model,
model=rrdbnet_model,
half=False,
tile=self.tile_size,
)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
TorchDevice.empty_cache()
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
image_dto = context.images.save(image=pil_image)
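The comment above about BGR vs RGB is easy to get wrong, so here is a tiny self-contained sketch of the round trip. fake_upscale is a hypothetical stand-in for the Real-ESRGAN call, which is not reproduced here:

import cv2
import numpy as np
from PIL import Image


def fake_upscale(bgr: np.ndarray) -> np.ndarray:
    # Placeholder: nearest-neighbour 4x upscale instead of the ESRGAN model.
    return cv2.resize(bgr, None, fx=4, fy=4, interpolation=cv2.INTER_NEAREST)


pil_image = Image.new("RGB", (64, 64), color=(255, 0, 0))
cv2_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)  # PIL is RGB, cv2 expects BGR
upscaled = fake_upscale(cv2_image)
result = Image.fromarray(cv2.cvtColor(upscaled, cv2.COLOR_BGR2RGB)).convert("RGBA")
print(result.size)  # (256, 256)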


@@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
bulk_download_id, bulk_download_item_id, bulk_download_item_name
)
def _signal_job_completed(
@@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
)
def _signal_job_failed(
@@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
self._invoker.services.events.emit_bulk_download_error(
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
)
def stop(self, *args, **kwargs):


@@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings):
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and stored on disk at this location.
download_cache_dir: Path to the directory that contains dynamically downloaded models.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
@@ -146,7 +147,8 @@ class InvokeAIAppConfig(BaseSettings):
# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
@@ -303,6 +305,11 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the converted cache models directory, resolved to an absolute path.."""
return self._resolve(self.convert_cache_dir)
@property
def download_cache_path(self) -> Path:
"""Path to the downloaded models directory, resolved to an absolute path.."""
return self._resolve(self.download_cache_dir)
@property
def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory, resolved to an absolute path.."""


@@ -1,10 +1,17 @@
"""Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_base import (
DownloadJob,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadJob,
UnknownJobIDException,
)
from .download_default import DownloadQueueService, TqdmProgress
__all__ = [
"DownloadJob",
"MultiFileDownloadJob",
"DownloadQueueServiceBase",
"DownloadQueueService",
"TqdmProgress",


@@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
from invokeai.backend.model_manager.metadata import RemoteModelFile
class DownloadJobStatus(str, Enum):
"""State of a download job."""
@@ -33,30 +35,23 @@ class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started."""
DownloadEventHandler = Callable[["DownloadJob"], None]
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
@total_ordering
class DownloadJob(BaseModel):
"""Class to monitor and control a model download request."""
class DownloadJobBase(BaseModel):
"""Base of classes to monitor and control downloads."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)")
@@ -74,14 +69,6 @@ class DownloadJob(BaseModel):
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
def cancel(self) -> None:
"""Call to cancel the job."""
self._cancelled = True
@@ -98,6 +85,11 @@ class DownloadJob(BaseModel):
"""Return true if job completed without errors."""
return self.status == DownloadJobStatus.COMPLETED
@property
def waiting(self) -> bool:
"""Return true if the job is waiting to run."""
return self.status == DownloadJobStatus.WAITING
@property
def running(self) -> bool:
"""Return true if the job is running."""
@@ -154,6 +146,37 @@ class DownloadJob(BaseModel):
self._on_cancelled = on_cancelled
@total_ordering
class DownloadJob(DownloadJobBase):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
class MultiFileDownloadJob(DownloadJobBase):
"""Class to monitor and control multifile downloads."""
download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL."""
@@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC):
"""
pass
@abstractmethod
def multifile_download(
self,
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
"""
Create and enqueue a multifile download job.
:param parts: Set of URL / filename pairs
:param dest: Path to download to. See below.
:param access_token: Access token to download the indicated files. If not provided,
each file's URL may be matched to an access token using the config file matching
system.
:param submit_job: If true [default] then submit the job for execution. Otherwise,
you will need to pass the job to submit_multifile_download().
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
The `dest` argument is a Path object pointing to a directory. All downloads
will be placed inside this directory. The callbacks will receive the
MultiFileDownloadJob.
"""
pass
@abstractmethod
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
"""
Enqueue a previously-created multi-file download job.
:param job: A MultiFileDownloadJob created with multifile_download()
"""
pass
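A hedged usage sketch of the two methods above, assuming `queue` is an already-started DownloadQueueService and that RemoteModelFile exposes url and path fields (an assumption for illustration; only the import appears in this diff):

from pathlib import Path

from invokeai.backend.model_manager.metadata import RemoteModelFile

# Hypothetical file list; the RemoteModelFile field names are assumed here.
parts = [
    RemoteModelFile(url="https://example.com/model/config.json", path=Path("config.json")),
    RemoteModelFile(url="https://example.com/model/weights.safetensors", path=Path("weights.safetensors")),
]

# `queue` is assumed to be a running DownloadQueueService instance.
job = queue.multifile_download(
    parts=parts,
    dest=Path("models/.download_cache/example-model"),
    on_complete=lambda j: print(f"downloaded {len(j.download_parts)} parts to {j.dest}"),
)
queue.wait_for_job(job)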
@abstractmethod
def submit_download_job(
self,
@@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def cancel_job(self, job: DownloadJob) -> None:
def cancel_job(self, job: DownloadJobBase) -> None:
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Wait until the indicated download job has reached a terminal state.
This will block until the indicated download job has completed,

View File

@@ -8,24 +8,28 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Literal, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.model_manager.metadata import RemoteModelFile
from invokeai.backend.util.logging import InvokeAILogger
from .download_base import (
DownloadEventHandler,
DownloadExceptionHandler,
DownloadJob,
DownloadJobBase,
DownloadJobCancelledException,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadJob,
ServiceInactiveException,
UnknownJobIDException,
)
@@ -40,20 +44,24 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__(
self,
max_parallel_dl: int = 5,
event_bus: Optional[EventServiceBase] = None,
app_config: Optional[InvokeAIAppConfig] = None,
event_bus: Optional["EventServiceBase"] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
:param app_config: InvokeAIAppConfig object
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {}
self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
self._next_job_id = 0
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
self._stop_event = threading.Event()
self._job_completed_event = threading.Event()
self._job_terminated_event = threading.Event()
self._worker_pool: Set[threading.Thread] = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@@ -105,18 +113,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
with self._lock:
job.id = self._next_job_id
self._next_job_id += 1
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[job.id] = job
self._queue.put(job)
job.id = self._next_id()
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[job.id] = job
self._queue.put(job)
def download(
self,
@@ -139,7 +145,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
source=source,
dest=dest,
priority=priority,
access_token=access_token,
access_token=access_token or self._lookup_access_token(source),
)
self.submit_download_job(
job,
@@ -151,10 +157,63 @@ class DownloadQueueService(DownloadQueueServiceBase):
)
return job
def multifile_download(
self,
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
mfdj.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
for part in parts:
url = part.url
path = dest / part.path
assert path.is_relative_to(dest), "only relative download paths accepted"
job = DownloadJob(
source=url,
dest=path,
access_token=access_token,
)
mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj
if submit_job:
self.submit_multifile_download(mfdj)
return mfdj
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
for download_job in job.download_parts:
self.submit_download_job(
download_job,
on_start=self._mfd_started,
on_progress=self._mfd_progress,
on_complete=self._mfd_complete,
on_cancelled=self._mfd_cancelled,
on_error=self._mfd_error,
)
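A hedged sketch of the deferred-submission pattern these two methods allow, reusing `queue` and `parts` from the earlier sketch: build the job with `submit_job=False`, inspect or record it, then enqueue it later.

```python
job = queue.multifile_download(parts=parts, dest=Path("/tmp/models"), submit_job=False)
print(f"job {job.id} has {len(job.download_parts)} parts")  # nothing has been enqueued yet
queue.submit_multifile_download(job)  # now each part is submitted with the _mfd_* callbacks
```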
def join(self) -> None:
"""Wait for all jobs to complete."""
self._queue.join()
def _next_id(self) -> int:
with self._lock:
id = self._next_job_id
self._next_job_id += 1
return id
def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs."""
return list(self._jobs.values())
@@ -176,14 +235,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJob) -> None:
def cancel_job(self, job: DownloadJobBase) -> None:
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
job.cancel()
def cancel_all_jobs(self) -> None:
@@ -192,12 +251,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state:
self.cancel_job(job)
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
self._job_completed_event.clear()
if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event
self._job_terminated_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
return job
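Because `wait_for_job()` polls the terminated event in 0.25 s slices, a positive `timeout` surfaces as a `TimeoutError`. A hedged sketch of handling it, assuming `queue` and `job` from the earlier sketches:

```python
try:
    queue.wait_for_job(job, timeout=60)
except TimeoutError:
    queue.cancel_job(job)  # give up and let the queue clear any partial downloads
```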
@@ -226,22 +285,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
except Exception as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
finally:
job.job_ended = get_iso_timestamp()
self._job_completed_event.set() # signal a change to terminal state
self._job_terminated_event.set() # signal a change to terminal state
self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
self._job_terminated_event.set()
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
@@ -333,79 +395,53 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading")
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
# Pull the token from config if it exists and matches the URL
token = None
for pair in self._app_config.remote_api_tokens or []:
if re.search(pair.url_regex, str(source)):
token = pair.token
break
return token
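The matching rule is simple enough to illustrate standalone: the first configured entry whose `url_regex` matches the source URL supplies the bearer token. The entry shape below mirrors the attributes the loop reads (`url_regex`, `token`); the values are made up.

```python
import re
from types import SimpleNamespace

remote_api_tokens = [
    SimpleNamespace(url_regex=r"civitai\.com", token="civitai-example-token"),
    SimpleNamespace(url_regex=r"huggingface\.co", token="hf-example-token"),
]

source = "https://huggingface.co/foo/bar/resolve/main/model.safetensors"
# first matching entry wins, otherwise None (no token)
token = next((p.token for p in remote_api_tokens if re.search(p.url_regex, source)), None)
assert token == "hf-example-token"
```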
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
if job.on_start:
try:
job.on_start(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_start")
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
self._event_bus.emit_download_started(job)
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
try:
job.on_progress(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_progress")
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
self._event_bus.emit_download_progress(job)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
if job.on_complete:
try:
job.on_complete(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_complete")
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
self._event_bus.emit_download_complete(job)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
return
job.status = DownloadJobStatus.CANCELLED
if job.on_cancelled:
try:
job.on_cancelled(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_cancelled")
if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source))
self._event_bus.emit_download_cancelled(job)
# if multifile download, then signal the parent
if parent_job := self._download_part2parent.get(job.source, None):
if not parent_job.in_terminal_state:
parent_job.status = DownloadJobStatus.CANCELLED
self._execute_cb(parent_job, "on_cancelled")
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
if job.on_error:
try:
job.on_error(job, excp)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_error", excp)
if self._event_bus:
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
self._event_bus.emit_download_error(job)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
@@ -416,6 +452,97 @@ class DownloadQueueService(DownloadQueueServiceBase):
except OSError as excp:
self._logger.warning(excp)
########################################
# callbacks used for multifile downloads
########################################
def _mfd_started(self, download_job: DownloadJob) -> None:
self._logger.info(f"File download started: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.waiting:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.status = DownloadJobStatus.RUNNING
assert download_job.download_path is not None
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
mf_job.download_path = (
mf_job.dest / path_relative_to_destdir.parts[0]
) # keep just the first component of the path
self._execute_cb(mf_job, "on_start")
def _mfd_progress(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.cancelled:
for part in mf_job.download_parts:
self.cancel_job(part)
elif mf_job.running:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.bytes = sum(x.bytes for x in mf_job.download_parts)  # aggregate bytes downloaded so far across parts
self._execute_cb(mf_job, "on_progress")
def _mfd_complete(self, download_job: DownloadJob) -> None:
self._logger.info(f"Download complete: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
# are there any more active jobs left in this task?
if mf_job.running and all(x.complete for x in mf_job.download_parts):
mf_job.status = DownloadJobStatus.COMPLETED
self._execute_cb(mf_job, "on_complete")
# we're done with this sub-job
self._job_terminated_event.set()
def _mfd_cancelled(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
self._logger.warning(f"Download cancelled: {download_job.source}")
mf_job.cancel()
for s in mf_job.download_parts:
self.cancel_job(s)
def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
mf_job.status = download_job.status
mf_job.error = download_job.error
mf_job.error_type = download_job.error_type
self._execute_cb(mf_job, "on_error", excp)
self._logger.error(
f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
)
for s in [x for x in mf_job.download_parts if x.running]:
self.cancel_job(s)
self._download_part2parent.pop(download_job.source)
self._job_terminated_event.set()
def _execute_cb(
self,
job: DownloadJob | MultiFileDownloadJob,
callback_name: Literal[
"on_start",
"on_progress",
"on_complete",
"on_cancelled",
"on_error",
],
excp: Optional[Exception] = None,
) -> None:
if callback := getattr(job, callback_name, None):
args = [job, excp] if excp else [job]
try:
callback(*args)
except Exception as e:
self._logger.error(
f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
)
def get_pc_name_max(directory: str) -> int:
if hasattr(os, "pathconf"):

View File

@@ -1,490 +1,199 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Optional
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallDownloadStartedEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None:
def dispatch(self, event: "EventBase") -> None:
pass
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
# region: Invocation
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
payload={"event": event_name, "data": payload},
)
def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
def __emit_download_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
def emit_invocation_denoise_progress(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
source_node_id: str,
progress_image: Optional[ProgressImage],
step: int,
order: int,
total_steps: int,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
) -> None:
"""Emitted when there is generation progress"""
self.__emit_queue_event(
event_name="generator_progress",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
)
"""Emitted at each step during denoising of an invocation."""
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
def emit_invocation_complete(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
"""Emitted when an invocation is complete"""
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
def emit_invocation_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
error_type: str,
error: str,
user_id: str | None,
project_id: str | None,
error_message: str,
error_traceback: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
"user_id": user_id,
"project_id": project_id,
},
)
"""Emitted when an invocation encounters an error"""
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))
def emit_invocation_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
# endregion
def emit_graph_execution_complete(
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_model_load_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_model_load_completed(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
# region Queue
def emit_queue_item_status_changed(
self,
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
) -> None:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
},
)
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id))
def emit_download_started(self, source: str, download_path: str) -> None:
"""
Emit when a download job is started.
# endregion
:param url: The downloaded url
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
# region Download
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
"""
Emit "download_progress" events at regular intervals during a download job.
def emit_download_started(self, job: "DownloadJob") -> None:
"""Emitted when a download is started"""
self.dispatch(DownloadStartedEvent.build(job))
:param source: The downloaded source
:param download_path: The local downloaded file
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_progress(self, job: "DownloadJob") -> None:
"""Emitted at intervals during a download"""
self.dispatch(DownloadProgressEvent.build(job))
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
"""
Emit a "download_complete" event at the end of a successful download.
def emit_download_complete(self, job: "DownloadJob") -> None:
"""Emitted when a download is completed"""
self.dispatch(DownloadCompleteEvent.build(job))
:param source: Source URL
:param download_path: Path to the locally downloaded file
:param total_bytes: The size of the downloaded file
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_cancelled(self, job: "DownloadJob") -> None:
"""Emitted when a download is cancelled"""
self.dispatch(DownloadCancelledEvent.build(job))
def emit_download_cancelled(self, source: str) -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
self.__emit_download_event(
event_name="download_cancelled",
payload={
"source": source,
},
)
def emit_download_error(self, job: "DownloadJob") -> None:
"""Emitted when a download encounters an error"""
self.dispatch(DownloadErrorEvent.build(job))
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
"""
Emit a "download_error" event when an download job encounters an exception.
# endregion
:param source: Source URL
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
# region Model loading
def emit_model_install_downloading(
self,
source: str,
local_path: str,
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
id: int,
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
def emit_model_load_complete(
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
) -> None:
"""
Emit at intervals while the install job is in progress (remote models only).
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
:param source: Source of the model
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
# endregion
def emit_model_install_downloads_done(self, source: str) -> None:
"""
Emit once when all parts are downloaded, but before the probing and registration start.
# region Model install
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_downloads_done",
payload={"source": source},
)
def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
"""Emitted at intervals while the install job is started (remote models only)."""
self.dispatch(ModelInstallDownloadStartedEvent.build(job))
def emit_model_install_running(self, source: str) -> None:
"""
Emit once when an install job becomes active.
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
"""Emitted at intervals while the install job is in progress (remote models only)."""
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_running",
payload={"source": source},
)
def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
def emit_model_install_started(self, job: "ModelInstallJob") -> None:
"""Emitted once when an install job is started (after any download)."""
self.dispatch(ModelInstallStartedEvent.build(job))
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
:param total_bytes: Size of the model (may be None for installation of a local path)
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job is completed successfully."""
self.dispatch(ModelInstallCompleteEvent.build(job))
def emit_model_install_cancelled(self, source: str, id: int) -> None:
"""
Emit when an install job is cancelled.
def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job is cancelled."""
self.dispatch(ModelInstallCancelledEvent.build(job))
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source, "id": id},
)
def emit_model_install_error(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job encounters an exception."""
self.dispatch(ModelInstallErrorEvent.build(job))
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
"""
Emit when an install job encounters an exception.
# endregion
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
# region Bulk image download
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
"""Emitted when a bulk image download is started"""
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
def emit_bulk_download_completed(
def emit_bulk_download_complete(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
"""Emitted when a bulk image download is complete"""
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
def emit_bulk_download_failed(
def emit_bulk_download_error(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error,
},
"""Emitted when a bulk image download has an error"""
self.dispatch(
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
)
# endregion

View File

@@ -0,0 +1,628 @@
from math import floor
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
Events must define a class attribute `__event_name__` to identify the event.
All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created.
"""
__event_name__: ClassVar[str]
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@classmethod
def get_events(cls) -> set[type["EventBase"]]:
"""Get a set of all event models."""
event_subclasses: set[type["EventBase"]] = set()
for subclass in cls.__subclasses__():
# We only want to include subclasses that are event models, not intermediary classes
if hasattr(subclass, "__event_name__"):
event_subclasses.add(subclass)
event_subclasses.update(subclass.get_events())
return event_subclasses
TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
A tuple representing a `fastapi-events` event, with the event name and payload.
Provide a generic type to `TEvent` to specify the payload type.
"""
class FastAPIEventFunc(Protocol, Generic[TEvent]):
def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None:
"""Register a function to handle specific events.
:param events: An event or set of events to handle
:param func: The function to handle the events
"""
events = events if isinstance(events, set) else {events}
for event in events:
assert hasattr(event, "__event_name__")
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
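A hedged sketch of the intended pattern (the event class and handler below are made up for illustration; only `EventBase`, `FastAPIEvent` and `register_events` come from this module): subclass `EventBase`, give it an `__event_name__`, then register a handler for it.

```python
from pydantic import Field

from invokeai.app.services.events.events_common import EventBase, FastAPIEvent, register_events


class ExampleThingHappenedEvent(EventBase):
    """Event model for example_thing_happened (illustrative only)."""

    __event_name__ = "example_thing_happened"

    detail: str = Field(description="What happened")


def on_thing_happened(event: FastAPIEvent[ExampleThingHappenedEvent]) -> None:
    event_name, payload = event  # a fastapi-events event is a (name, payload) tuple
    print(f"{event_name}: {payload.detail} at {payload.timestamp}")


register_events(ExampleThingHappenedEvent, on_thing_happened)
```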
class QueueEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
class QueueItemEventBase(QueueEventBase):
"""Base class for queue item events"""
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
class InvocationEventBase(QueueItemEventBase):
"""Base class for invocation events"""
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
__event_name__ = "invocation_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
)
@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
__event_name__ = "invocation_denoise_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
)
@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
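A short worked example of the arithmetic above, offered as illustration only: with step=10, total_steps=30 and a second-order scheduler, the result is floor(12/2) / floor(31/2) = 6/15 = 0.4; with a first-order scheduler it is 12/31.

```python
from invokeai.app.services.events.events_common import InvocationDenoiseProgressEvent

assert InvocationDenoiseProgressEvent.calc_percentage(step=10, total_steps=30, scheduler_order=2) == 0.4
assert InvocationDenoiseProgressEvent.calc_percentage(step=10, total_steps=30, scheduler_order=1) == 12 / 31
```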
@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
__event_name__ = "invocation_complete"
result: AnyInvocationOutput = Field(description="The result of the invocation")
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
result=result,
)
@payload_schema.register
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
__event_name__ = "invocation_error"
error_type: str = Field(description="The error type")
error_message: str = Field(description="The error message")
error_traceback: str = Field(description="The error traceback")
user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
project_id: Optional[str] = Field(default=None, description="The ID of the project with which the invocation is associated")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
error_type: str,
error_message: str,
error_traceback: str,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None),
)
@payload_schema.register
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
__event_name__ = "queue_item_status_changed"
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
error_type: Optional[str] = Field(default=None, description="The error type, if any")
error_message: Optional[str] = Field(default=None, description="The error message, if any")
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
@classmethod
def build(
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,
error_message=queue_item.error_message,
error_traceback=queue_item.error_traceback,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
)
@payload_schema.register
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
__event_name__ = "batch_enqueued"
batch_id: str = Field(description="The ID of the batch")
enqueued: int = Field(description="The number of invocations enqueued")
requested: int = Field(
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
)
@payload_schema.register
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str) -> "QueueClearedEvent":
return cls(queue_id=queue_id)
class DownloadEventBase(EventBase):
"""Base class for events associated with a download"""
source: str = Field(description="The source of the download")
@payload_schema.register
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
__event_name__ = "download_started"
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix())
@payload_schema.register
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
__event_name__ = "download_progress"
download_path: str = Field(description="The local path where the download is saved")
current_bytes: int = Field(description="The number of bytes downloaded so far")
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
__event_name__ = "download_complete"
download_path: str = Field(description="The local path where the download is saved")
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
@payload_schema.register
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
__event_name__ = "download_cancelled"
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
return cls(source=str(job.source))
@payload_schema.register
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
__event_name__ = "download_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelInstallDownloadStartedEvent(ModelEventBase):
"""Event model for model_install_download_started"""
__event_name__ = "model_install_download_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
__event_name__ = "model_install_download_progress"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
__event_name__ = "model_install_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
__event_name__ = "model_install_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
__event_name__ = "model_install_cancelled"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
__event_name__ = "model_install_error"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEventBase(EventBase):
"""Base class for events associated with a bulk image download"""
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
__event_name__ = "bulk_download_started"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
__event_name__ = "bulk_download_complete"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""
__event_name__ = "bulk_download_error"
error: str = Field(description="The error message")
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
)

View File

@@ -0,0 +1,47 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_common import (
EventBase,
)
from .events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error
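One practical note, offered as a hedged sketch rather than documented behaviour: because `__init__` calls `asyncio.create_task()`, the service must be constructed while an event loop is already running (e.g. during FastAPI startup). The handler id below is a placeholder, `QueueClearedEvent` comes from events_common above, and the sketch assumes it runs alongside the class defined above (the module path is not shown in the diff).

```python
import asyncio

from invokeai.app.services.events.events_common import QueueClearedEvent


async def main() -> None:
    service = FastAPIEventService(event_handler_id=1234)  # placeholder handler id
    service.dispatch(QueueClearedEvent.build(queue_id="default"))
    await asyncio.sleep(0.2)  # give the background task a chance to drain the queue
    service.stop()


asyncio.run(main())
```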

View File

@@ -1,11 +1,13 @@
"""Initialization file for model install service package."""
from .model_install_base import (
ModelInstallServiceBase,
)
from .model_install_common import (
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
URLModelSource,

View File

@@ -1,244 +1,19 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer."""
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources; implements functions that let the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
from invokeai.backend.model_manager import AnyModelConfig
class ModelInstallServiceBase(ABC):
@@ -282,7 +57,7 @@ class ModelInstallServiceBase(ABC):
@property
@abstractmethod
def event_bus(self) -> Optional[EventServiceBase]:
def event_bus(self) -> Optional["EventServiceBase"]:
"""Return the event service base object associated with the installer."""
@abstractmethod
@@ -468,12 +243,11 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
:param source: A Url or a string that can be converted into one.
:param access_token: Optional access token to access restricted resources.
:param source: A string representing a URL or repo_id.
The model file will be downloaded into the system-wide model cache
(`models/.cache`) if it isn't already there. Note that the model cache

View File

@@ -0,0 +1,227 @@
import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources; implements functions that let the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
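# --- Illustrative sketch (not part of the diff above) ---
# The ModelSource discriminated union defined in this module can be validated generically
# with a pydantic TypeAdapter; the payload below is a placeholder.
from pydantic import TypeAdapter

adapter = TypeAdapter(ModelSource)
src = adapter.validate_python({"type": "hf", "repo_id": "stabilityai/sdxl-turbo"})
assert isinstance(src, HFModelSource)
print(str(src))  # repo_id plus the default fp16 variant, roughly "stabilityai/sdxl-turbo:fp16"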

View File

@@ -5,23 +5,24 @@ import os
import re
import threading
import time
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
from pydantic_core import Url
from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
@@ -44,14 +45,14 @@ from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify
from .model_install_base import (
from .model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
StringLikeSource,
URLModelSource,
@@ -68,7 +69,7 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
event_bus: Optional[EventServiceBase] = None,
event_bus: Optional["EventServiceBase"] = None,
session: Optional[Session] = None,
):
"""
@@ -89,7 +90,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event()
self._download_queue = download_queue
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._download_cache: Dict[int, ModelInstallJob] = {}
self._running = False
self._session = session
self._install_thread: Optional[threading.Thread] = None
@@ -104,7 +105,7 @@ class ModelInstallService(ModelInstallServiceBase):
return self._record_store
@property
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
return self._event_bus
# make the invoker optional here because we don't need it and it
@@ -208,33 +209,12 @@ class ModelInstallService(ModelInstallServiceBase):
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
subfolder=Path(match.group(3)) if match.group(3) else None,
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
"""Install a model using pattern matching to infer the type of source."""
source_obj = self._guess_source(source)
if isinstance(source_obj, LocalModelSource):
source_obj.inplace = inplace
elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
source_obj.access_token = access_token
return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
@@ -295,8 +275,9 @@ class ModelInstallService(ModelInstallServiceBase):
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
job.cancel()
with self._lock:
self._cancel_download_parts(job)
self._logger.warning(f"Cancelling {job.source}")
if dj := job._multifile_job:
self._download_queue.cancel_job(dj)
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
@@ -344,7 +325,7 @@ class ModelInstallService(ModelInstallServiceBase):
legacy_config_path = stanza.get("config")
if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path)
@@ -384,38 +365,95 @@ class ModelInstallService(ModelInstallServiceBase):
rmtree(model_path)
self.unregister(key)
def download_and_cache(
@classmethod
def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
escaped_source = slugify(str(source))
return app_config.download_cache_path / escaped_source
def download_and_cache_model(
self,
source: Union[str, AnyHttpUrl],
access_token: Optional[str] = None,
timeout: int = 0,
source: str | AnyHttpUrl,
) -> Path:
"""Download the model file located at source to the models cache and return its Path."""
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
model_path = self._app_config.convert_cache_path / model_hash
model_path = self._download_cache_path(str(source), self._app_config)
# We expect the cache directory to contain one and only one downloaded file.
# We expect the cache directory to contain one and only one downloaded file or directory.
# We don't know the file's name in advance, as it is set by the download
# content-disposition header.
if model_path.exists():
contents = [x for x in model_path.iterdir() if x.is_file()]
contents: List[Path] = list(model_path.iterdir())
if len(contents) > 0:
return contents[0]
model_path.mkdir(parents=True, exist_ok=True)
job = self._download_queue.download(
source=AnyHttpUrl(str(source)),
model_source = self._guess_source(str(source))
remote_files, _ = self._remote_files_from_source(model_source)
job = self._multifile_download(
dest=model_path,
access_token=access_token,
on_progress=TqdmProgress().update,
remote_files=remote_files,
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
)
self._download_queue.wait_for_job(job, timeout)
files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
self._download_queue.wait_for_job(job)
if job.complete:
assert job.download_path is not None
return job.download_path
else:
raise Exception(job.error)
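# --- Illustrative usage sketch (not part of the diff above); `installer` and the URL are placeholders. ---
cached = installer.download_and_cache_model(
    "https://huggingface.co/some-org/some-model/resolve/main/model.safetensors"
)
# A second call with the same source short-circuits: the slugified per-source cache
# directory already contains the download, so its path is returned without re-downloading.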
def _remote_files_from_source(
self, source: ModelSource
) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
metadata = None
if isinstance(source, HFModelSource):
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
assert isinstance(metadata, ModelMetadataWithFiles)
return (
metadata.download_urls(
variant=source.variant or self._guess_variant(),
subfolder=source.subfolder,
session=self._session,
),
metadata,
)
if isinstance(source, URLModelSource):
try:
fetcher = self.get_fetcher_from_url(str(source.url))
kwargs: dict[str, Any] = {"session": self._session}
metadata = fetcher(**kwargs).from_url(source.url)
assert isinstance(metadata, ModelMetadataWithFiles)
return metadata.download_urls(session=self._session), metadata
except ValueError:
pass
return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
raise Exception(f"No files associated with {source}")
def _guess_source(self, source: str) -> ModelSource:
"""Turn a source string into a ModelSource object."""
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than ''
subfolder=Path(match.group(3)) if match.group(3) else None,
)
elif re.match(r"^https?://[^/]+", source):
source_obj = URLModelSource(
url=Url(source),
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
return source_obj
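# Illustrative inputs for _guess_source() above (paths and repo ids are placeholders):
#   "/data/models/foo.safetensors"            -> LocalModelSource(path=...)
#   "stabilityai/sdxl-turbo"                  -> HFModelSource(repo_id=..., variant=None)
#   "stabilityai/sdxl-turbo:fp16:vae"         -> HFModelSource(variant=fp16, subfolder=vae)
#   "https://example.com/model.safetensors"   -> URLModelSource(url=...)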
# --------------------------------------------------------------------------------------------
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
@@ -476,16 +514,19 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
job.set_error(
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
multifile_download_job = install_job._multifile_job
if multifile_download_job and any(
x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
):
install_job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
install_job.set_error(excp)
self._signal_job_errored(install_job)
# --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory
@@ -511,7 +552,6 @@ class ModelInstallService(ModelInstallServiceBase):
This is typically only used during testing with a new DB or when using the memory DB, because those are the
only situations in which we may have orphaned models in the models directory.
"""
installed_model_paths = {
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
}
@@ -523,8 +563,13 @@ class ModelInstallService(ModelInstallServiceBase):
if resolved_path in installed_model_paths:
return True
# Skip core models entirely - these aren't registered with the model manager.
if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
return False
for special_directory in [
self.app_config.models_path / "core",
self.app_config.convert_cache_dir,
self.app_config.download_cache_dir,
]:
if resolved_path.is_relative_to(special_directory):
return False
try:
model_id = self.register_path(model_path)
self._logger.info(f"Registered {model_path.name} with id {model_id}")
@@ -639,20 +684,15 @@ class ModelInstallService(ModelInstallServiceBase):
inplace=source.inplace or False,
)
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
def _import_from_hf(
self,
source: HFModelSource,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
source.access_token = source.access_token or HfFolder.get_token()
if not source.access_token:
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(
variant=source.variant or self._guess_variant(),
subfolder=source.subfolder,
session=self._session,
)
if source.access_token is None:
source.access_token = HfFolder.get_token()
remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
source=source,
config=config,
@@ -660,22 +700,12 @@ class ModelInstallService(ModelInstallServiceBase):
metadata=metadata,
)
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from HuggingFace will be handled specially
metadata = None
fetcher = None
try:
fetcher = self.get_fetcher_from_url(str(source.url))
except ValueError:
pass
kwargs: dict[str, Any] = {"session": self._session}
if fetcher is not None:
metadata = fetcher(**kwargs).from_url(source.url)
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
else:
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
def _import_from_url(
self,
source: URLModelSource,
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
source=source,
config=config,
@@ -690,12 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
tmpdir = Path(
destdir = Path(
mkdtemp(
dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX,
@@ -706,55 +733,28 @@ class ModelInstallService(ModelInstallServiceBase):
source=source,
config_in=config or {},
source_metadata=metadata,
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0,
total_bytes=0,
)
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if isinstance(source, HFModelSource) and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
root = Path(".")
subfolder = Path(".")
# remember the temporary directory for later removal
install_job._install_tmpdir = destdir
install_job.total_bytes = sum((x.size or 0) for x in remote_files)
# we remember the path up to the top of the tmpdir so that it may be
# removed safely at the end of the install process.
install_job._install_tmpdir = tmpdir
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
multifile_job = self._multifile_download(
remote_files=remote_files,
dest=destdir,
subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
access_token=source.access_token,
submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
)
self._download_cache[multifile_job.id] = install_job
install_job._multifile_job = multifile_job
files_string = "file" if len(remote_files) == 1 else "file"
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
self._logger.debug(f"remote_files={remote_files}")
for model_file in remote_files:
url = model_file.url
path = root / model_file.path.relative_to(subfolder)
self._logger.debug(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size
assert hasattr(source, "access_token")
dest = tmpdir / path.parent
dest.mkdir(parents=True, exist_ok=True)
download_job = DownloadJob(
source=url,
dest=dest,
access_token=source.access_token,
)
self._download_cache[download_job.source] = install_job # matches a download job to an install job
install_job.download_parts.add(download_job)
# only start the jobs once install_job.download_parts is fully populated
for download_job in install_job.download_parts:
self._download_queue.submit_download_job(
download_job,
on_start=self._download_started_callback,
on_progress=self._download_progress_callback,
on_complete=self._download_complete_callback,
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
self._download_queue.submit_multifile_download(multifile_job)
return install_job
def _stat_size(self, path: Path) -> int:
@@ -766,87 +766,104 @@ class ModelInstallService(ModelInstallServiceBase):
size += sum(self._stat_size(Path(root, x)) for x in files)
return size
def _multifile_download(
self,
remote_files: List[RemoteModelFile],
dest: Path,
subfolder: Optional[Path] = None,
access_token: Optional[str] = None,
submit_job: bool = True,
) -> MultiFileDownloadJob:
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
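# For example (illustrative names): a remote file at "sdxl-turbo/vae/config.json", installed
# with subfolder "vae", ends up at "sdxl-turbo_vae/config.json".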
if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")
parts: List[RemoteModelFile] = []
for model_file in remote_files:
assert model_file.size is not None
parts.append(
RemoteModelFile(
url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json
path=path_to_add / model_file.path.relative_to(path_to_remove),
)
)
return self._download_queue.multifile_download(
parts=parts,
dest=dest,
access_token=access_token,
submit_job=submit_job,
on_start=self._download_started_callback,
on_progress=self._download_progress_callback,
on_complete=self._download_complete_callback,
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
# ------------------------------------------------------------------
# Callbacks are executed by the download queue in a separate thread
# ------------------------------------------------------------------
def _download_started_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"Model download started: {download_job.source}")
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
install_job.status = InstallStatus.DOWNLOADING
if install_job := self._download_cache.get(download_job.id, None):
install_job.status = InstallStatus.DOWNLOADING
assert download_job.download_path
if install_job.local_path == install_job._install_tmpdir:
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
dest_name = partial_path.parts[0]
install_job.local_path = install_job._install_tmpdir / dest_name
if install_job.local_path == install_job._install_tmpdir: # first time
assert download_job.download_path
install_job.local_path = download_job.download_path
install_job.download_parts = download_job.download_parts
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
install_job.total_bytes = download_job.total_bytes
self._signal_job_download_started(install_job)
# Update the total bytes count for remote sources.
if not install_job.total_bytes:
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
def _download_progress_callback(self, download_job: DownloadJob) -> None:
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
self._cancel_download_parts(install_job)
else:
# update sizes
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
self._signal_job_downloading(install_job)
if install_job := self._download_cache.get(download_job.id, None):
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
self._download_queue.cancel_job(download_job)
else:
# update sizes
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"Model download complete: {download_job.source}")
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
if install_job := self._download_cache.pop(download_job.id, None):
self._signal_job_downloads_done(install_job)
self._put_in_queue(install_job)
self._put_in_queue(install_job) # this starts the installation and registration
# Let other threads know that the number of downloads has changed
self._download_cache.pop(download_job.source, None)
self._downloads_changed_event.set()
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
install_job = self._download_cache.pop(download_job.source, None)
assert install_job is not None
assert excp is not None
install_job.set_error(excp)
self._logger.error(
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
)
self._cancel_download_parts(install_job)
if install_job := self._download_cache.pop(download_job.id, None):
assert excp is not None
install_job.set_error(excp)
self._download_queue.cancel_job(download_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache.pop(download_job.source, None)
if not install_job:
return
self._downloads_changed_event.set()
self._logger.warning(f"Model download canceled: {download_job.source}")
# if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored:
install_job.cancel()
self._cancel_download_parts(install_job)
if install_job := self._download_cache.pop(download_job.id, None):
self._downloads_changed_event.set()
# if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored:
install_job.cancel()
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
# do not lock here because it gets called within a locked context
for s in install_job.download_parts:
self._download_queue.cancel_job(s)
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._put_in_queue(install_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
@@ -855,35 +872,27 @@ class ModelInstallService(ModelInstallServiceBase):
job.status = InstallStatus.RUNNING
self._logger.info(f"Model install started: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_running(str(job.source))
self._event_bus.emit_model_install_started(job)
def _signal_job_download_started(self, job: ModelInstallJob) -> None:
if self._event_bus:
assert job._multifile_job is not None
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_download_started(job)
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus:
parts: List[Dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
assert job._multifile_job is not None
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_downloading(
str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
id=job.id,
)
self._event_bus.emit_model_install_download_progress(job)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.DOWNLOADS_DONE
self._logger.info(f"Model download complete: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_downloads_done(str(job.source))
self._event_bus.emit_model_install_downloads_complete(job)
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
@@ -893,25 +902,28 @@ class ModelInstallService(ModelInstallServiceBase):
if self._event_bus:
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
self._event_bus.emit_model_install_complete(job)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
if self._event_bus:
error_type = job.error_type
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
assert job.error_type is not None
assert job.error is not None
self._event_bus.emit_model_install_error(job)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"Model install canceled: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
self._event_bus.emit_model_install_cancelled(job)
@staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
"""
Return a metadata fetcher appropriate for provided url.
This used to be more useful, but the number of supported model
sources has been reduced to HuggingFace alone.
"""
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")

View File

@@ -2,11 +2,11 @@
"""Base class for model loader."""
from abc import ABC, abstractmethod
from typing import Optional
from pathlib import Path
from typing import Callable, Optional
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@@ -15,18 +15,12 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting
"""
@property
@@ -38,3 +32,26 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
"""
Load the model file or directory located at the indicated Path.
This will load an arbitrary model file into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
LoadedModel, but without the config attribute.
Args:
model_path: A pathlib.Path to a checkpoint-style models file
loader: A Callable that expects a Path and returns a Dict[str, Tensor]
Returns:
A LoadedModelWithoutConfig object.
"""

View File

@@ -1,19 +1,26 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""
from typing import Optional, Type
from pathlib import Path
from typing import Callable, Optional, Type
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase
@@ -51,25 +58,18 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
)
# We don't have an invoker during testing
# TODO(psyche): Mock this method on the invoker in the tests
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
@@ -79,40 +79,45 @@ class ModelLoadService(ModelLoadServiceBase):
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
loaded=True,
)
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
return loaded_model
def _emit_load_event(
self,
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
submodel_type: Optional[SubModelType] = None,
) -> None:
if not self._invoker:
return
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
except IndexError:
pass
if not loaded:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
else:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
result = torch_load(checkpoint, map_location="cpu")
return result
def diffusers_load_directory(directory: Path) -> AnyModel:
load_class = GenericDiffusersLoader(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self.convert_cache,
).get_hf_load_class(directory)
return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
loader = loader or (
diffusers_load_directory
if model_path.is_dir()
else torch_load_file
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu")
)
assert loader is not None
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
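# Illustrative default loader selection performed above (paths are placeholders):
#   a directory                        -> diffusers_load_directory (GenericDiffusersLoader)
#   *.ckpt, *.pt, *.pth, *.bin         -> torch_load_file (picklescan check, then torch.load)
#   anything else, e.g. *.safetensors  -> safetensors_load_file(..., device="cpu")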

View File

@@ -12,15 +12,13 @@ from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)

View File

@@ -1,6 +1,49 @@
from abc import ABC, abstractmethod
from threading import Event
from typing import Optional, Protocol
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.util.profiler import Profiler
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
"""Starts the session runner.
Args:
services: The invocation services.
cancel_event: The cancel event.
profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session
stats will still be recorded and logged when profiling is disabled.
"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs a session.
Args:
queue_item: The session to run.
"""
pass
@abstractmethod
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Run a single node in the graph.
Args:
invocation: The invocation to run.
queue_item: The session queue item.
"""
pass
class SessionProcessorBase(ABC):
@@ -26,3 +69,85 @@ class SessionProcessorBase(ABC):
def get_status(self) -> SessionProcessorStatus:
"""Gets the status of the session processor"""
pass
class OnBeforeRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that will be executed.
queue_item: The session queue item.
"""
...
class OnAfterRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that was executed.
queue_item: The session queue item.
"""
...
class OnNodeError(Protocol):
def __call__(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a node has an error.
Args:
invocation: The invocation that errored.
queue_item: The session queue item.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...
class OnBeforeRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnAfterRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run after executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnNonFatalProcessorError(Protocol):
def __call__(
self,
queue_item: Optional[SessionQueueItem],
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a non-fatal error occurs in the processor.
Args:
queue_item: The session queue item, if one was being executed when the error occurred.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...
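# --- Illustrative sketch (not part of the diff above) ---
# A hedged example of a callback satisfying the OnAfterRunNode protocol; it could be passed
# to the DefaultSessionRunner shown further below. Names are placeholders.
def log_node_output(
    invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
) -> None:
    print(f"{type(invocation).__name__} in session {queue_item.session_id} -> {type(output).__name__}")

runner = DefaultSessionRunner(on_after_run_node_callbacks=[log_node_output])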

View File

@@ -4,24 +4,325 @@ from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
FastAPIEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
register_events,
)
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_base import (
OnAfterRunNode,
OnAfterRunSession,
OnBeforeRunNode,
OnBeforeRunSession,
OnNodeError,
OnNonFatalProcessorError,
)
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations."""
def __init__(
self,
on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
):
"""
Args:
on_before_run_session_callbacks: Callbacks to run before the session starts.
on_before_run_node_callbacks: Callbacks to run before each node starts.
on_after_run_node_callbacks: Callbacks to run after each node completes.
on_node_error_callbacks: Callbacks to run when a node errors.
on_after_run_session_callbacks: Callbacks to run after the session completes.
"""
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
def _is_canceled(self) -> bool:
"""Check if the cancel event is set. This is also passed to the invocation context builder and called during
denoising to check if the session has been canceled."""
return self._cancel_event.is_set()
def run(self, queue_item: SessionQueueItem):
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
self._on_before_run_session(queue_item=queue_item)
# Loop over invocations until the session is complete or canceled
while True:
try:
invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=e.node,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
break
if invocation is None or self._is_canceled():
break
self.run_node(invocation, queue_item)
# The session is complete if all invocations have been run or there is an error on the session.
# At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must
# use the cancel event to check if the session is canceled.
if (
queue_item.session.is_complete()
or self._is_canceled()
or queue_item.status in ["failed", "canceled", "completed"]
):
break
self._on_after_run_session(queue_item=queue_item)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
context = build_invocation_context(
data=data,
services=self._services,
is_canceled=self._is_canceled,
)
# Invoke the node
output = invocation.invoke_internal(context=context, services=self._services)
# Save output and history
queue_item.session.complete(invocation.id, output)
self._on_after_run_node(invocation, queue_item, output)
except KeyboardInterrupt:
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
pass
except CanceledException:
# A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need
# to do any handling here, and no error should be set - just pass and the cancellation will be handled
# correctly in the next iteration of the session runner loop.
#
# See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we
# handle cancellation.
pass
except Exception as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
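The cancellation contract described in the comments above can be sketched roughly as follows (an illustration, not code from this changeset; the step_callback name and its wiring are hypothetical):

from typing import Callable

from invokeai.app.services.session_processor.session_processor_common import CanceledException


def step_callback(is_canceled: Callable[[], bool]) -> None:
    # Long-running nodes (e.g. denoising) call a callback like this between steps.
    # Raising CanceledException aborts the node immediately; run_node() above swallows
    # the exception, and the runner loop then observes the cancel event and breaks.
    if is_canceled():
        raise CanceledException()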
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called before a session is run.
- Start the profiler if profiling is enabled.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=queue_item.session_id)
for callback in self._on_before_run_session_callbacks:
callback(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called after a session is run.
- Stop the profiler if profiling is enabled.
- Update the queue item's session object in the database.
- If not already canceled or failed, complete the queue item.
- Log and reset performance statistics.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler is not None:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
)
try:
# Update the queue item with the completed session. If the queue item has been removed from the queue,
# we'll get a SessionQueueItemNotFoundError and we can ignore it. This can happen if the queue is cleared
# while the session is running.
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
# The queue item may have been canceled or failed while the session was running. We should only complete it
# if it is not already canceled or failed.
if queue_item.status not in ["canceled", "failed"]:
queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
except SessionQueueItemNotFoundError:
pass
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Called before a node is run.
- Emits an invocation started event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send starting event
self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation)
for callback in self._on_before_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item)
def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
):
"""Called after a node is run.
- Emits an invocation complete event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send complete event on successful runs
self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output)
for callback in self._on_after_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item, output=output)
def _on_node_error(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
error_type: str,
error_message: str,
error_traceback: str,
):
"""Called when a node errors. Node errors may occur when running or preparing the node..
- Set the node error on the session object.
- Log the error.
- Fail the queue item.
- Emits an invocation error event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
node_error = f"{error_type}: {error_message}"
queue_item.session.set_node_error(invocation.id, node_error)
self._services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
)
self._services.logger.error(error_traceback)
# Fail the queue item
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
queue_item = self._services.session_queue.fail_queue_item(
queue_item.item_id, error_type, error_message, error_traceback
)
# Send error event
self._services.events.emit_invocation_error(
queue_item=queue_item,
invocation=invocation,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_node_error_callbacks:
callback(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
def __init__(
self,
session_runner: Optional[SessionRunnerBase] = None,
on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
thread_limit: int = 1,
polling_interval: int = 1,
) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
self._thread_limit = thread_limit
self._polling_interval = polling_interval
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
@@ -31,11 +332,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
register_events(QueueClearedEvent, self._on_queue_cleared)
register_events(BatchEnqueuedEvent, self._on_batch_enqueued)
register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed)
self._thread_limit = thread_limit
self._thread_semaphore = BoundedSemaphore(thread_limit)
self._polling_interval = polling_interval
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
# the profiler will create a new profile for each session.
@@ -49,6 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
self._thread = Thread(
name="session_processor",
target=self._process,
@@ -67,30 +369,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None:
self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
self._cancel_event.set()
self._poll_now()
if (
event_name == "session_canceled"
and self._queue_item
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
):
self._cancel_event.set()
self._poll_now()
elif (
event_name == "queue_cleared"
and self._queue_item
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
):
self._cancel_event.set()
self._poll_now()
elif event_name == "batch_enqueued":
self._poll_now()
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
"completed",
"failed",
"canceled",
]:
async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None:
self._poll_now()
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
#
# Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such
# node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item
# is canceled, and if it is, raises a `CanceledException` to stop execution immediately.
if event[1].status == "canceled":
self._cancel_event.set()
self._poll_now()
def resume(self) -> SessionProcessorStatus:
@@ -116,8 +413,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
resume_event: ThreadEvent,
cancel_event: ThreadEvent,
):
# Outermost processor try block; any unhandled exception is a fatal processor error
try:
# Any unhandled exception in this block is a fatal processor error and will stop the processor.
self._thread_semaphore.acquire()
stop_event.clear()
resume_event.set()
@@ -125,8 +422,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
while not stop_event.is_set():
poll_now_event.clear()
# Middle processor try block; any unhandled exception is a non-fatal processor error
try:
# Any unhandled exception in this block is a non-fatal processor error and will be handled.
# If we are paused, wait for resume event
resume_event.wait()
@@ -142,159 +439,69 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Run the graph
self.session_runner.run(queue_item=self._queue_item)
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
user_id=None,
project_id=None,
)
pass
# The session is complete if all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
else:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
except Exception:
# Non-fatal error in processor
self._invoker.services.logger.error(
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
except Exception as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_non_fatal_processor_error(
queue_item=self._queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
# Cancel the queue item
if self._queue_item is not None:
self._invoker.services.session_queue.cancel_queue_item(
self._queue_item.item_id, error=traceback.format_exc()
)
# Reset the invocation to None to prepare for the next session
self._invocation = None
# Immediately poll for next queue item
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval)
continue
except Exception:
except Exception as e:
# Fatal error in processor, log and pass - we're done here
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
pass
finally:
stop_event.clear()
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
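The outer and middle error tiers of _process can be reduced to a small runnable toy (illustrative only; _toy_process and its inputs are hypothetical):

def _toy_process(work: list[str]) -> None:
    try:
        # Outer tier: anything escaping this block is a fatal processor error and stops the loop.
        for item in work:
            try:
                # Middle tier: a failure here is non-fatal; the current item is failed and polling continues.
                if item == "bad":
                    raise ValueError("session failed")
                print(f"ran {item}")
            except Exception as e:
                print(f"non-fatal: {type(e).__name__}: {e}")
    except Exception as e:
        print(f"fatal: {type(e).__name__}: {e}")


_toy_process(["a", "bad", "b"])  # prints: ran a / non-fatal: ValueError: session failed / ran b

The innermost tier (per-invocation errors) lives in DefaultSessionRunner.run_node(), which fails only the graph rather than the processor.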
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Called when a non-fatal error occurs in the processor.
- Log the error.
- If a queue item is provided, update the queue item with the completed session & fail it.
- Run any callbacks registered for this event.
"""
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
if queue_item is not None:
# Update the queue item with the completed session & fail it
queue_item = self._invoker.services.session_queue.set_queue_item_session(
queue_item.item_id, queue_item.session
)
queue_item = self._invoker.services.session_queue.fail_queue_item(
item_id=queue_item.item_id,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_non_fatal_processor_error_callbacks:
callback(
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
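Putting the pieces together, wiring a custom runner and callbacks into the processor could look roughly like this (a sketch assuming both classes live in session_processor_default, and reusing the hypothetical log_node_completion and log_node_error callbacks from the earlier sketch):

from invokeai.app.services.session_processor.session_processor_default import (
    DefaultSessionProcessor,
    DefaultSessionRunner,
)

processor = DefaultSessionProcessor(
    session_runner=DefaultSessionRunner(
        on_after_run_node_callbacks=[log_node_completion],
        on_node_error_callbacks=[log_node_error],
    ),
    on_non_fatal_processor_error_callbacks=[],
    thread_limit=1,
    polling_interval=1,
)
# processor.start(invoker) is called by the application's service wiring; the Invoker is omitted here.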

View File

@@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -73,10 +74,22 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
"""Completes a session queue item"""
pass
@abstractmethod
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
"""Cancels a session queue item"""
pass
@abstractmethod
def fail_queue_item(
self, item_id: int, error_type: str, error_message: str, error_traceback: str
) -> SessionQueueItem:
"""Fails a session queue item"""
pass
@abstractmethod
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs"""
@@ -103,3 +116,8 @@ class SessionQueueBase(ABC):
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""
pass
@abstractmethod
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
"""Sets the session for a session queue item. Use this to update the session state."""
pass

View File

@@ -3,7 +3,16 @@ import json
from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator
from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
StrictStr,
TypeAdapter,
field_validator,
model_validator,
)
from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation
@@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel):
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
error_traceback: Optional[str] = Field(
default=None,
description="The error traceback if this queue item errored",
validation_alias=AliasChoices("error_traceback", "error"),
)
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
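The AliasChoices on error_traceback keeps pre-migration payloads readable: rows or API bodies that still carry the old "error" key validate into the new field. A self-contained pydantic sketch of the same behaviour (the Item model is illustrative, not an InvokeAI type):

from typing import Optional

from pydantic import AliasChoices, BaseModel, Field


class Item(BaseModel):
    error_traceback: Optional[str] = Field(
        default=None, validation_alias=AliasChoices("error_traceback", "error")
    )


assert Item.model_validate({"error": "Traceback ..."}).error_traceback == "Traceback ..."
assert Item.model_validate({"error_traceback": "Traceback ..."}).error_traceback == "Traceback ..."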

View File

@@ -2,10 +2,6 @@ import sqlite3
import threading
from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
@@ -27,6 +23,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -41,7 +38,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
@@ -51,52 +48,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name == "invocation_error":
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
except SessionQueueItemNotFoundError:
return
async def _handle_error_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item(item_id)
# always set to failed if have an error, even if previously the item was marked completed or canceled
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
except SessionQueueItemNotFoundError:
return
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
except SessionQueueItemNotFoundError:
return
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
@@ -271,17 +222,22 @@ class SqliteSessionQueue(SessionQueueBase):
return SessionQueueItem.queue_item_from_dict(dict(result))
def _set_queue_item_status(
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
self,
item_id: int,
status: QUEUE_ITEM_STATUS,
error_type: Optional[str] = None,
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?, error = ?
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
WHERE item_id = ?
""",
(status, error, item_id),
(status, error_type, error_message, error_traceback, item_id),
)
self.__conn.commit()
except Exception:
@@ -292,11 +248,7 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
@@ -338,26 +290,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return IsFullResult(is_full=is_full)
def delete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id=item_id)
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
DELETE FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return queue_item
def clear(self, queue_id: str) -> ClearResult:
try:
self.__lock.acquire()
@@ -424,17 +356,28 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
return queue_item
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
return queue_item
def fail_queue_item(
self,
item_id: int,
error_type: str,
error_message: str,
error_traceback: str,
) -> SessionQueueItem:
queue_item = self._set_queue_item_status(
item_id=item_id,
status="failed",
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
@@ -470,18 +413,10 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
current_queue_item, batch_status, queue_status
)
except Exception:
self.__conn.rollback()
@@ -521,18 +456,10 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
current_queue_item, batch_status, queue_status
)
except Exception:
self.__conn.rollback()
@@ -562,6 +489,29 @@ class SqliteSessionQueue(SessionQueueBase):
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
session_json = session.model_dump_json(warnings=False, exclude_none=True)
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET session = ?
WHERE item_id = ?
""",
(session_json, item_id),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
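The exclude_none choice noted in the comment above is plain pydantic behaviour; a minimal illustration (the Example model is not an InvokeAI type):

from typing import Optional

from pydantic import BaseModel


class Example(BaseModel):
    id: str = "abc"
    error: Optional[str] = None


print(Example().model_dump_json())                   # {"id":"abc","error":null}
print(Example().model_dump_json(exclude_none=True))  # {"id":"abc"} - no nulls to trip validation later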
def list_queue_items(
self,
queue_id: str,
@@ -578,7 +528,9 @@ class SqliteSessionQueue(SessionQueueBase):
status,
priority,
field_values,
error,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,

View File

@@ -2,17 +2,19 @@
import copy
import itertools
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import (
BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
field_validator,
)
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema
from pydantic_core import core_schema
# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403
@@ -190,6 +192,39 @@ class UnknownGraphValidationError(ValueError):
pass
class NodeInputError(ValueError):
"""Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but an
input fails validation.
Attributes:
node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the error to
determine which field caused the failure.
"""
def __init__(self, node: BaseInvocation, e: ValidationError):
self.original_error = e
self.node = node
# When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error
# represents the first input that failed.
self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"])
super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}")
def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str:
"""Helper to pretty-print pydantic error locations as dot-separated strings.
Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages
"""
path = ""
for i, x in enumerate(loc):
if isinstance(x, str):
if i > 0:
path += "."
path += x
else:
path += f"[{x}]"
return path
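For example, pydantic error location tuples map to dotted paths with bracketed list indices (an illustrative check, not part of the changeset):

# loc tuples come from ValidationError.errors()[n]["loc"]
assert loc_to_dot_sep(("nodes", "my_node", "image")) == "nodes.my_node.image"
assert loc_to_dot_sep(("collection", 2, "value")) == "collection[2].value"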
@invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
@@ -243,73 +278,58 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection))
class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@@ -740,7 +760,7 @@ class GraphExecutionState(BaseModel):
)
# The results of executed nodes
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@@ -757,52 +777,12 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
v.validate_self()
return v
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state")
graph: Graph = Field(description="The graph being executed")
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
executed: set[str] = Field(description="The set of node ids that have been executed")
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution"
)
results: dict[
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
] = Field(description="The results of node executions")
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes"
)
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes"
)
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute."""
@@ -821,7 +801,10 @@ class GraphExecutionState(BaseModel):
# Get values from edges
if next_node is not None:
self._prepare_inputs(next_node)
try:
self._prepare_inputs(next_node)
except ValidationError as e:
raise NodeInputError(next_node, e)
# If next is still none, there's no next node, return None
return next_node

View File

@@ -1,9 +1,9 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Union
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES
@@ -15,8 +15,15 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -321,8 +328,10 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Checks if a model exists.
"""Check if a model exists.
Args:
identifier: The key or ModelField representing the model.
@@ -332,13 +341,13 @@ class ModelsInterface(InvocationContextInterface):
"""
if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
return self._services.model_manager.store.exists(identifier.key)
else:
return self._services.model_manager.store.exists(identifier.key)
def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model.
"""Load a model.
Args:
identifier: The key or ModelField representing the model.
@@ -353,16 +362,16 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
return self._services.model_manager.load.load_model(model, submodel_type)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
return self._services.model_manager.load.load_model(model, _submodel_type)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model by its attributes.
"""Load a model by its attributes.
Args:
name: Name of the model.
@@ -382,10 +391,10 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config.
"""Get a model's config.
Args:
identifier: The key or ModelField representing the model.
@@ -395,11 +404,11 @@ class ModelsInterface(InvocationContextInterface):
"""
if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.store.get_model(identifier.key)
else:
return self._services.model_manager.store.get_model(identifier.key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path.
"""Search for models by path.
Args:
path: The path to search for.
@@ -416,7 +425,7 @@ class ModelsInterface(InvocationContextInterface):
type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
"""Searches for models by attributes.
"""Search for models by attributes.
Args:
name: The name to search for (exact match).
@@ -435,6 +444,72 @@ class ModelsInterface(InvocationContextInterface):
model_format=format,
)
def download_and_cache_model(
self,
source: str | AnyHttpUrl,
) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
This can be used to install single-file models and other resources of arbitrary types
that should not be registered with the database. If the model is already
installed, the cached path will be returned; otherwise it will be downloaded.
Args:
source: A URL that points to the model, or a huggingface repo_id.
Returns:
Path to the downloaded model
"""
return self._services.model_manager.install.download_and_cache_model(source=source)
def load_local_model(
self,
model_path: Path,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
Load the model file located at the indicated path
If a loader callable is provided, it will be invoked to load the model. Otherwise,
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
Args:
model_path: A model Path
loader: A Callable that expects a Path and returns a dict[str|int, Any]
Returns:
A LoadedModelWithoutConfig object.
"""
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def load_remote_model(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
Download, cache, and load the model file located at the indicated URL or repo_id.
If the model is already downloaded, it will be loaded from the cache.
If a loader callable is provided, it will be invoked to load the model. Otherwise,
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
Args:
source: A URL or huggingface repo_id.
loader: A Callable that expects a Path and returns a dict[str|int, Any]
Returns:
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
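A rough sketch of how a node might call these helpers from inside its invoke() method (the wrapper function and the URL are hypothetical; only method names and keyword arguments shown in this changeset are used, and the returned LoadedModelWithoutConfig is not exercised here):

def _demo(context) -> None:
    # 'context' stands in for the InvocationContext passed to an invocation's invoke().
    # Download (or reuse from the cache) a single-file model that is not registered in the DB:
    weights_path = context.models.download_and_cache_model(source="https://example.com/weights.safetensors")
    # Load it from disk, optionally with a custom loader callable:
    loaded_local = context.models.load_local_model(model_path=weights_path)
    # Or download, cache, and load in one step:
    loaded_remote = context.models.load_remote_model(source="https://example.com/weights.safetensors")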
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
@@ -449,10 +524,10 @@ class ConfigInterface(InvocationContextInterface):
class UtilInterface(InvocationContextInterface):
def __init__(
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool]
) -> None:
super().__init__(services, data)
self._cancel_event = cancel_event
self._is_canceled = is_canceled
def is_canceled(self) -> bool:
"""Checks if the current session has been canceled.
@@ -460,7 +535,7 @@ class UtilInterface(InvocationContextInterface):
Returns:
True if the current session has been canceled, False if not.
"""
return self._cancel_event.is_set()
return self._is_canceled()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
"""
@@ -535,7 +610,7 @@ class InvocationContext:
def build_invocation_context(
services: InvocationServices,
data: InvocationContextData,
cancel_event: threading.Event,
is_canceled: Callable[[], bool],
) -> InvocationContext:
"""Builds the invocation context for a specific invocation execution.
@@ -552,7 +627,7 @@ def build_invocation_context(
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data)
boards = BoardsInterface(services=services, data=data)

View File

@@ -12,6 +12,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -41,6 +43,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_7())
migrator.register_migration(build_migration_8(app_config=config))
migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10())
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.run_migrations()
return db

View File

@@ -0,0 +1,35 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration10Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._update_error_cols(cursor)
def _update_error_cols(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `error_type` and `error_message` columns to the session queue table.
- Renames the `error` column to `error_traceback`.
"""
cursor.execute("ALTER TABLE session_queue ADD COLUMN error_type TEXT;")
cursor.execute("ALTER TABLE session_queue ADD COLUMN error_message TEXT;")
cursor.execute("ALTER TABLE session_queue RENAME COLUMN error TO error_traceback;")
def build_migration_10() -> Migration:
"""
Build the migration from database version 9 to 10.
This migration does the following:
- Adds `error_type` and `error_message` columns to the session queue table.
- Renames the `error` column to `error_traceback`.
"""
migration_10 = Migration(
from_version=9,
to_version=10,
callback=Migration10Callback(),
)
return migration_10

View File

@@ -0,0 +1,75 @@
import shutil
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
LEGACY_CORE_MODELS = [
# OpenPose
"any/annotators/dwpose/yolox_l.onnx",
"any/annotators/dwpose/dw-ll_ucoco_384.onnx",
# DepthAnything
"any/annotators/depth_anything/depth_anything_vitl14.pth",
"any/annotators/depth_anything/depth_anything_vitb14.pth",
"any/annotators/depth_anything/depth_anything_vits14.pth",
# Lama inpaint
"core/misc/lama/lama.pt",
# RealESRGAN upscale
"core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
]
class Migration11Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_convert_cache()
self._remove_downloaded_models()
self._remove_unused_core_models()
def _remove_convert_cache(self) -> None:
"""Rename models/.cache to models/.convert_cache."""
self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
legacy_convert_path = self._app_config.root_path / "models" / ".cache"
shutil.rmtree(legacy_convert_path, ignore_errors=True)
def _remove_downloaded_models(self) -> None:
"""Remove models from their old locations; they will re-download when needed."""
self._logger.info(
"Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
)
for model_path in LEGACY_CORE_MODELS:
legacy_dest_path = self._app_config.models_path / model_path
legacy_dest_path.unlink(missing_ok=True)
def _remove_unused_core_models(self) -> None:
"""Remove unused core models and their directories."""
self._logger.info("Removing defunct core models.")
for dir in ["face_restoration", "misc", "upscaling"]:
path_to_remove = self._app_config.models_path / "core" / dir
shutil.rmtree(path_to_remove, ignore_errors=True)
shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""
Build the migration from database version 10 to 11.
This migration does the following:
- Moves "core" models previously downloaded with download_with_progress_bar() into new
"models/.download_cache" directory.
- Renames "models/.cache" to "models/.convert_cache".
"""
migration_11 = Migration(
from_version=10,
to_version=11,
callback=Migration11Callback(app_config=app_config, logger=logger),
)
return migration_11

View File

@@ -289,7 +289,7 @@ def prepare_control_image(
width: int,
height: int,
num_channels: int = 3,
device: str = "cuda",
device: str | torch.device = "cuda",
dtype: torch.dtype = torch.float16,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
@@ -304,7 +304,7 @@ def prepare_control_image(
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str, optional): The target device for the output image. Defaults to "cuda".
device (str | torch.device, optional): The target device for the output image. Defaults to "cuda".
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.

View File

@@ -0,0 +1,116 @@
from typing import Any, Callable, Optional
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
component_schema."""
defs = component_schema.pop("$defs", {})
for schema_key, json_schema in defs.items():
if schema_key in openapi_schema["components"]["schemas"]:
continue
openapi_schema["components"]["schemas"][schema_key] = json_schema
def get_openapi_func(
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
) -> Callable[[], dict[str, Any]]:
"""Gets the OpenAPI schema generator function.
Args:
app (FastAPI): The FastAPI app to generate the schema for.
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
generated schema before returning it. Defaults to None.
Returns:
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
matches FastAPI's default schema generation caching.
"""
def openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# We'll create a map of invocation type to output schema to make some types simpler on the client.
invocation_output_map_properties: dict[str, Any] = {}
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
move_defs_to_top_level(openapi_schema, json_schema)
output_title = invocation.get_output_annotation().__name__
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
json_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
# Add this invocation and its output to the output map
invocation_type = invocation.get_type()
invocation_output_map_properties[invocation_type] = json_schema["output"]
invocation_output_map_required.append(invocation_type)
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"required": invocation_output_map_required,
}
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
# `models_json_schema` seems to work fine.
additional_models = [
*EventBase.get_events(),
UIConfigBase,
InputFieldJSONSchemaExtra,
OutputFieldJSONSchemaExtra,
ModelIdentifierField,
ProgressImage,
]
additional_schemas = models_json_schema(
[(m, "serialization") for m in additional_models],
ref_template="#/components/schemas/{model}",
)
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])
if post_transform is not None:
openapi_schema = post_transform(openapi_schema)
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
app.openapi_schema = openapi_schema
return app.openapi_schema
return openapi
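A brief usage sketch may help here: the generator returned by `get_openapi_func` is meant to replace FastAPI's default `app.openapi` method, so the customized, cached schema is what gets served at `/openapi.json`. The app construction below is an assumption for illustration, not the API server's actual wiring.
```
# Hypothetical wiring: swap FastAPI's default schema generator for the custom one.
from fastapi import FastAPI

app = FastAPI(title="Example app")
app.openapi = get_openapi_func(app)  # first call builds and caches the schema
```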

View File

@@ -1,51 +0,0 @@
from pathlib import Path
from urllib import request
from tqdm import tqdm
from invokeai.backend.util.logging import InvokeAILogger
class ProgressBar:
"""Simple progress bar for urllib.request.urlretrieve using tqdm."""
def __init__(self, model_name: str = "file"):
self.pbar = None
self.name = model_name
def __call__(self, block_num: int, block_size: int, total_size: int):
if not self.pbar:
self.pbar = tqdm(
desc=self.name,
initial=0,
unit="iB",
unit_scale=True,
unit_divisor=1000,
total=total_size,
)
self.pbar.update(block_size)
def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
"""Download a file from a URL to a destination path, with a progress bar.
If the file already exists, it will not be downloaded again.
Exceptions are not caught.
Args:
name (str): Name of the file being downloaded.
url (str): URL to download the file from.
dest_path (Path): Destination path to save the file to.
Returns:
bool: True if the file was downloaded, False if it already existed.
"""
if dest_path.exists():
return False # already downloaded
InvokeAILogger.get_logger().info(f"Downloading {name}...")
dest_path.parent.mkdir(parents=True, exist_ok=True)
request.urlretrieve(url, dest_path, ProgressBar(name))
return True

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, Optional
import torch
from PIL import Image
@@ -13,8 +13,36 @@ if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContextData
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
SDXL_LATENT_RGB_FACTORS = [
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
]
SDXL_SMOOTH_MATRIX = [
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
]
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
# originally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
SD1_5_LATENT_RGB_FACTORS = [
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
]
def sample_to_lowres_estimated_image(
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None:
@@ -47,64 +75,12 @@ def stable_diffusion_step_callback(
else:
sample = intermediate_state.latents
# TODO: This does not seem to be needed any more?
# # txt2img provides a Tensor in the step_callback
# # img2img provides a PipelineIntermediateState
# if isinstance(sample, PipelineIntermediateState):
# # this was an img2img
# print('img2img')
# latents = sample.latents
# step = sample.step
# else:
# print('txt2img')
# latents = sample
# step = intermediate_state.step
# TODO: only output a preview image when requested
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_smooth_matrix = torch.tensor(
[
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
else:
# originally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
],
dtype=sample.dtype,
device=sample.device,
)
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
(width, height) = image.size
@@ -113,15 +89,9 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_generator_progress(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
node_id=context_data.invocation.id,
source_node_id=context_data.source_invocation_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,
total_steps=intermediate_state.total_steps,
events.emit_invocation_denoise_progress(
context_data.queue_item,
context_data.invocation,
intermediate_state,
ProgressImage(dataURL=dataURL, width=width, height=height),
)
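To make the preview math above concrete, here is a minimal sketch of the projection the callback performs: each 4-channel latent pixel is multiplied by the per-model 4x3 factor matrix to produce an approximate RGB value. The dummy latent below is an assumption for illustration; `SDXL_LATENT_RGB_FACTORS` is the constant defined above.
```
# Illustration only: project a placeholder 4-channel latent to an approximate RGB preview.
import torch

latents = torch.randn(1, 4, 64, 64)  # dummy latent batch
factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=latents.dtype)
rgb_estimate = latents[0].permute(1, 2, 0) @ factors  # (64, 64, 4) @ (4, 3) -> (64, 64, 3)
```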

View File

@@ -1,5 +1,5 @@
import pathlib
from typing import Literal, Union
from pathlib import Path
from typing import Literal
import cv2
import numpy as np
@@ -10,28 +10,17 @@ from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = {
"large": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
},
"base": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
},
"small": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vits14.pth",
},
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
}
@@ -53,36 +42,27 @@ transform = Compose(
class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
self.device = TorchDevice.choose_torch_device()
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
self.model = model
self.device = device
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
download_with_progress_bar(
pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
DEPTH_ANYTHING_MODELS[model_size]["url"],
DEPTH_ANYTHING_MODEL_PATH,
)
@staticmethod
def load_model(
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
) -> DPT_DINOv2:
match model_size:
case "small":
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base":
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
if not self.model or model_size != self.model_size:
del self.model
self.model_size = model_size
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
model.eval()
match self.model_size:
case "small":
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base":
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()
self.model.to(self.device)
return self.model
model.to(device)
return model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
if not self.model:
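The refactor above takes downloading out of the detector: `load_model` becomes a static factory that builds a `DPT_DINOv2` and loads weights from a local path, and the detector is constructed with an already-loaded model. A hedged sketch of the resulting call pattern (the path and device are placeholders; in the app the model manager resolves and caches the checkpoint first):
```
# Sketch only: the checkpoint path is a placeholder, not a real install location.
from pathlib import Path
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DepthAnythingDetector.load_model(
    Path("depth_anything_vits14.pth"), device, model_size="small"
)
detector = DepthAnythingDetector(model, device)
# depth_map = detector(pil_image, resolution=512)
```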

View File

@@ -1,30 +1,53 @@
from pathlib import Path
from typing import Dict
import numpy as np
import torch
from controlnet_aux.util import resize_image
from PIL import Image
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
DWPOSE_MODELS = {
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
}
def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
def draw_pose(
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
H: int,
W: int,
draw_face: bool = True,
draw_body: bool = True,
draw_hands: bool = True,
resolution: int = 512,
) -> Image.Image:
bodies = pose["bodies"]
faces = pose["faces"]
hands = pose["hands"]
assert isinstance(bodies, dict)
candidate = bodies["candidate"]
assert isinstance(bodies, dict)
subset = bodies["subset"]
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
if draw_body:
canvas = draw_bodypose(canvas, candidate, subset)
if draw_hands:
assert isinstance(hands, np.ndarray)
canvas = draw_handpose(canvas, hands)
if draw_face:
canvas = draw_facepose(canvas, faces)
assert isinstance(hands, np.ndarray)
canvas = draw_facepose(canvas, faces) # type: ignore
dwpose_image = resize_image(
dwpose_image: Image.Image = resize_image(
canvas,
resolution,
)
@@ -39,11 +62,16 @@ class DWOpenposeDetector:
Credits: https://github.com/IDEA-Research/DWPose
"""
def __init__(self) -> None:
self.pose_estimation = Wholebody()
def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
def __call__(
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
self,
image: Image.Image,
draw_face: bool = False,
draw_body: bool = True,
draw_hands: bool = False,
resolution: int = 512,
) -> Image.Image:
np_image = np.array(image)
H, W, C = np_image.shape
@@ -79,3 +107,6 @@ class DWOpenposeDetector:
return draw_pose(
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
)
__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]

View File

@@ -5,11 +5,13 @@ import math
import cv2
import matplotlib
import numpy as np
import numpy.typing as npt
eps = 0.01
NDArrayInt = npt.NDArray[np.uint8]
def draw_bodypose(canvas, candidate, subset):
def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
candidate = np.array(candidate)
subset = np.array(subset)
@@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
return canvas
def draw_handpose(canvas, all_hand_peaks):
def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
edges = [
@@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
return canvas
def draw_facepose(canvas, all_lmks):
def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
for lmks in all_lmks:
lmks = np.array(lmks)

View File

@@ -2,47 +2,26 @@
# Modified pathing to suit Invoke
from pathlib import Path
import numpy as np
import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice
from .onnxdet import inference_detector
from .onnxpose import inference_pose
DWPOSE_MODELS = {
"yolox_l.onnx": {
"local": "any/annotators/dwpose/yolox_l.onnx",
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
},
"dw-ll_ucoco_384.onnx": {
"local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
},
}
config = get_config()
class Wholebody:
def __init__(self):
def __init__(self, onnx_det: Path, onnx_pose: Path):
device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
download_with_progress_bar(
"dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
)
onnx_det = DET_MODEL_PATH
onnx_pose = POSE_MODEL_PATH
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

View File

@@ -1,4 +1,4 @@
import gc
from pathlib import Path
from typing import Any
import numpy as np
@@ -6,9 +6,7 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.model_manager.config import AnyModel
def norm_img(np_img):
@@ -19,28 +17,11 @@ def norm_img(np_img):
return np_img
def load_jit_model(url_or_path, device):
model_path = url_or_path
logger.info(f"Loading model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu").to(device)
model.eval()
return model
class LaMA:
def __init__(self, model: AnyModel):
self._model = model
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = TorchDevice.choose_torch_device()
model_location = get_config().models_path / "core/misc/lama/lama.pt"
if not model_location.exists():
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=model_location,
)
model = load_jit_model(model_location, device)
image = np.asarray(input_image.convert("RGB"))
image = norm_img(image)
@@ -48,20 +29,25 @@ class LaMA:
mask = np.asarray(mask)
mask = np.invert(mask)
mask = norm_img(mask)
mask = (mask > 0) * 1
device = next(self._model.buffers()).device
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode():
infilled_image = model(image, mask)
infilled_image = self._model(image, mask)
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
infilled_image = Image.fromarray(infilled_image)
del model
gc.collect()
torch.cuda.empty_cache()
return infilled_image
@staticmethod
def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
model_path = url_or_path
logger.info(f"Loading model from: {model_path}")
model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore
model.eval()
return model
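As with the other utilities, `LaMA` no longer downloads or loads its own weights: it is handed a ready model, and `load_jit_model` is now a static helper. A hedged sketch of the call pattern (the checkpoint path and images are placeholders; how the mask is passed is assumed):
```
# Sketch only: placeholder path; in the app the model manager loads the
# TorchScript checkpoint and passes it to the LaMA constructor.
lama = LaMA(LaMA.load_jit_model("lama.pt", device="cpu"))
# infilled = lama(pil_image, mask=pil_mask)  # mask forwarded via kwargs (assumed)
```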

View File

@@ -1,6 +1,5 @@
import math
from enum import Enum
from pathlib import Path
from typing import Any, Optional
import cv2
@@ -11,6 +10,7 @@ from cv2.typing import MatLike
from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.util.devices import TorchDevice
"""
@@ -52,7 +52,7 @@ class RealESRGAN:
def __init__(
self,
scale: int,
model_path: Path,
loadnet: AnyModel,
model: RRDBNet,
tile: int = 0,
tile_pad: int = 10,
@@ -67,8 +67,6 @@ class RealESRGAN:
self.half = half
self.device = TorchDevice.choose_torch_device()
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
# prefer to use params_ema
if "params_ema" in loadnet:
keyname = "params_ema"

View File

@@ -125,13 +125,16 @@ class IPAdapter(RawModel):
self.device, dtype=self.dtype
)
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
def to(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
):
if device is not None:
self.device = device
if dtype is not None:
self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype)
self.attn_weights.to(device=self.device, dtype=self.dtype)
self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
def calc_size(self):
# workaround for circular import

View File

@@ -61,9 +61,10 @@ class LoRALayerBase:
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
# TODO: find and debug lora/locon with bias
@@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
super().to(device=device, dtype=dtype, non_blocking=non_blocking)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoHALayer(LoRALayerBase):
@@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoKRLayer(LoRALayerBase):
@@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
@@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class FullLayer(LoRALayerBase):
@@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
class IA3Layer(LoRALayerBase):
@@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
def calc_size(self) -> int:
model_size = 0
@@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=True)
model.layers[layer_key] = layer
return model

View File

@@ -0,0 +1,24 @@
import json
from base64 import b64decode
def validate_hash(hash: str):
if ":" not in hash:
return
for enc_hash in hashes:
alg, hash_ = hash.split(":")
if alg == "blake3":
alg = "blake3_single"
map = json.loads(b64decode(enc_hash))
if alg in map:
if hash_ == map[alg]:
raise Exception("Unrecoverable Model Error")
hashes: list[str] = [
"eyJibGFrZTNfbXVsdGkiOiI3Yjc5ODZmM2QyNTk3MDZiMjVhZDRhM2NmNGM2MTcyNGNhZmQ0Yjc4NjI4MjIwNjMyZGU4NjVlM2UxNDEyMTVlIiwiYmxha2UzX3NpbmdsZSI6IjdiNzk4NmYzZDI1OTcwNmIyNWFkNGEzY2Y0YzYxNzI0Y2FmZDRiNzg2MjgyMjA2MzJkZTg2NWUzZTE0MTIxNWUiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNzdlZmU5MzRhZGQ3YmU5Njc3NmJkODM3NWJhZDQxN2QiLCJzaGExIjoiYmM2YzYxYzgwNDgyMTE2ZTY2ZGQyNTYwNjRkYTgxYjFlY2U4NzMzOCIsInNoYTIyNCI6IjgzNzNlZGM4ZTg4Y2UxMTljODdlOTM2OTY4ZWViMWNmMzdjZGY4NTBmZjhjOTZkYjNmMDc4YmE0Iiwic2hhMjU2IjoiNzNjYWMxZWRlZmUyZjdlODFkNjRiMTI2YjIxMmY2Yzk2ZTAwNjgyNGJjZmJkZDI3Y2E5NmUyNTk5ZTQwNzUwZiIsInNoYTM4NCI6IjlmNmUwNzlmOTNiNDlkMTg1YzEyNzY0OGQwNzE3YTA0N2E3MzYyNDI4YzY4MzBhNDViNzExODAwZDE4NjIwZDZjMjcwZGE3ZmY0Y2FjOTRmNGVmZDdiZWQ5OTlkOWU0ZCIsInNoYTUxMiI6IjAwNzE5MGUyYjk5ZjVlN2Q1OGZiYWI2YTk1YmY0NjJiODhkOTg1N2NlNjY4MTMyMGJmM2M0Y2ZiZmY0MjkxZmEzNTMyMTk3YzdkODc2YWQ3NjZhOTQyOTQ2Zjc1OWY2YTViNDBlM2I2MzM3YzIwNWI0M2JkOWMyN2JiMTljNzk0IiwiYmxha2UyYiI6IjlhN2VhNTQzY2ZhMmMzMWYyZDIyNjg2MjUwNzUyNDE0Mjc1OWJiZTA0MWZlMWJkMzQzNDM1MWQwNWZlYjI2OGY2MjU0OTFlMzlmMzdkYWQ4MGM2Y2UzYTE4ZjAxNGEzZjJiMmQ2OGU2OTc0MjRmNTU2M2Y5ZjlhYzc1MzJiMjEwIiwiYmxha2UycyI6ImYxZmMwMjA0YjdjNzIwNGJlNWI1YzY3NDEyYjQ2MjY5NWE3YjFlYWQ2M2E5ZGVkMjEzYjZmYTU0NGZjNjJlYzUiLCJzaGEzXzIyNCI6IjljZDQ3YTBhMzA3NmNmYzI0NjJhNTAzMjVmMjg4ZjFiYzJjMmY2NmU2ODIxODc5NjJhNzU0NjFmIiwic2hhM18yNTYiOiI4NTFlNGI1ZDI1MWZlZTFiYzk0ODU1OWNjMDNiNjhlNTllYWU5YWI1ZTUyYjA0OTgxYTRhOTU4YWQyMDdkYjYwIiwic2hhM18zODQiOiJiZDA2ZTRhZGFlMWQ0MTJmZjFjOTcxMDJkZDFlN2JmY2UzMDViYTgxMTgyNzM3NWY5NTI4OWJkOGIyYTUxNjdiMmUyNzZjODNjNTU3ODFhMTEyMDRhNzc5MTUwMzM5ZTEiLCJzaGEzXzUxMiI6ImQ1ZGQ2OGZmZmY5NGRhZjJhMDkzZTliNmM1MTBlZmZkNThmZTA0ODMyZGQzMzEyOTZmN2NkZmYzNmRhZmQ3NGMxY2VmNjUxNTBkZjk5OGM1ODgyY2MzMzk2MTk1ZTViYjc5OTY1OGFkMTQ3MzFiMjJmZWZiMWQzNmY2MWJjYzJjIiwic2hha2VfMTI4IjoiOWJlNTgwNWMwNjg1MmZmNDUzNGQ4ZDZmODYyMmFkOTJkMGUwMWE2Y2JmYjIwN2QxOTRmM2JkYThiOGNmNWU4ZiIsInNoYWtlXzI1NiI6IjRhYjgwYjY2MzcxYzdhNjBhYWM4NDVkMTZlNWMzZDNhMmM4M2FjM2FjZDNiNTBiNzdjYWYyYTNmMWMyY2ZjZjc5OGNjYjkxN2FjZjQzNzBmZDdjN2ZmODQ5M2Q3NGY1MWM4NGU3M2ViZGQ4MTRmM2MwMzk3YzI4ODlmNTI0Mzg3In0K",
"eyJibGFrZTNfbXVsdGkiOiI4ODlmYzIwMDA4NWY1NWY4YTA4MjhiODg3MDM0OTRhMGFmNWZkZGI5N2E2YmYwMDRjM2VkYTdiYzBkNDU0MjQzIiwiYmxha2UzX3NpbmdsZSI6Ijg4OWZjMjAwMDg1ZjU1ZjhhMDgyOGI4ODcwMzQ5NGEwYWY1ZmRkYjk3YTZiZjAwNGMzZWRhN2JjMGQ0NTQyNDMiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNTIzNTRhMzkzYTVmOGNjNmMyMzQ0OThiYjcxMDljYzEiLCJzaGExIjoiMTJmYmRhOGE3ZGUwOGMwNDc2NTA5OWY2NGNmMGIzYjcxMjc1MGM1NyIsInNoYTIyNCI6IjEyZWU3N2U0Y2NhODViMDk4YjdjNWJlMWFjNGMwNzljNGM3MmJmODA2YjdlZjU1NGI0NzgxZDkxIiwic2hhMjU2IjoiMjU1NTMwZDAyYTY4MjY4OWE5ZTZjMjRhOWZhMDM2OGNhODMxZTI1OTAyYjM2NzQyNzkwZTk3NzU1ZjEzMmNmNSIsInNoYTM4NCI6IjhkMGEyMTRlNDk0NGE2NGY3ZmZjNTg3MGY0ZWUyZTA0OGIzYjRjMmQ0MGRmMWFmYTVlOGE1ZWNkN2IwOTY3M2ZjNWI5YzM5Yzg4Yjc2YmIwY2I4ZjQ1ZjAxY2MwNjZkNCIsInNoYTUxMiI6Ijg3NTM3OWNiYzdlOGYyNzU4YjVjMDY5ZTU2ZWRjODY1ODE4MGFkNDEzNGMwMzY1NzM4ZjM1YjQwYzI2M2JkMTMwMzcwZTE0MzZkNDNmOGFhMTgyMTg5MzgzMTg1ODNhOWJhYTUyYTBjMTk1Mjg5OTQzYzZiYTY2NTg1Yjg5M2ZiIiwiYmxha2UyYiI6IjBhY2MwNWEwOGE5YjhhODNmZTVjYTk4ZmExMTg3NTYwNjk0MjY0YWUxNTI4NDliYzFkNzQzNTYzMzMyMTlhYTg3N2ZiNjc4MmRjZDZiOGIyYjM1MTkyNDQzNDE2ODJiMTQ3YmY2YTY3MDU2ZWIwOTQ4MzE1M2E4Y2ZiNTNmMTI0IiwiYmxha2UycyI6ImY5ZTRhZGRlNGEzZDRhOTZhOWUyNjVjMGVmMjdmZDNiNjA0NzI1NDllMTEyMWQzOGQwMTkxNTY5ZDY5YzdhYzAiLCJzaGEzXzIyNCI6ImM0NjQ3MGRjMjkyNGI0YjZkMTA2NDY5MDRiNWM2OGVjNTU2YmQ4MTA5NmVkMTA4YjZiMzQyZmU1Iiwic2hhM18yNTYiOiIwMDBlMThiZTI1MzYxYTk0NGExZTIwNjQ5ZmY0ZGM2OGRiZTk0OGNkNTYwY2I5MTFhODU1OTE3ODdkNWQ5YWYwIiwic2hhM18zODQiOiIzNDljZmVhMGUxZGE0NWZlMmYzNjJhMWFjZjI1ZTczOWNiNGQ0NDdiM2NiODUzZDVkYWNjMzU5ZmRhMWE1M2FhYWU5OTM2ZmFhZWM1NmFhZDkwMThhYjgxMTI4ZjI3N2YiLCJzaGEzXzUxMiI6ImMxNDgwNGY1YTNjNWE4ZGEyMTAyODk1YTFjZGU4MmIwNGYwZmY4OTczMTc0MmY2NDQyY2NmNzQ1OTQzYWQ5NGViOWZmMTNhZDg3YjRmODkxN2M5NmY5ZjMwZjkwYTFhYTI4OTI3OTkwMjg0ZDJhMzcyMjA0NjE4MTNiNDI0MzEyIiwic2hha2VfMTI4IjoiN2IxY2RkMWUyMzUzMzk0OTg5M2UyMmZkMTAwZmU0YjJhMTU1MDJmMTNjMTI0YzhiZDgxY2QwZDdlOWEzMGNmOCIsInNoYWtlXzI1NiI6ImI0NjMzZThhMjNkZDM0ODk0ZTIyNzc0ODYyNTE1MzVjYWFlNjkyMTdmOTQ0NTc3MzE1NTljODBjNWQ3M2ZkOTMxZTFjMDJlZDI0Yjc3MzE3OTJjMjVlNTZhYjg3NjI4YmJiMDgxNTU0MjU2MWY5ZGI2NWE0NDk4NDFmNGQzYTU4In0K",
"eyJibGFrZTNfbXVsdGkiOiI2Y2M0MmU4NGRiOGQyZTliYjA4YjUxNWUwYzlmYzg2NTViNDUwNGRlZDM1MzBlZjFjNTFjZWEwOWUxYThiNGYxIiwiYmxha2UzX3NpbmdsZSI6IjZjYzQyZTg0ZGI4ZDJlOWJiMDhiNTE1ZTBjOWZjODY1NWI0NTA0ZGVkMzUzMGVmMWM1MWNlYTA5ZTFhOGI0ZjEiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZDQwNjk3NTJhYjQ0NzFhZDliMDY3YmUxMmRjNTM2ZjYiLCJzaGExIjoiOGRjZmVlMjZjZjUyOTllMDBjN2QwZjJiZTc0NmVmMTlkZjliZGExNCIsInNoYTIyNCI6IjhjMzAzOTU3ZjI3NDNiMjUwNmQyYzIzY2VmNmU4MTQ5MTllZmE2MWM0MTFiMDk5ZmMzODc2MmRjIiwic2hhMjU2IjoiZDk3ZjQ2OWJjMWZkMjhjMjZkMjJhN2Y3ODczNzlhZmM4NjY3ZmZmM2FhYTQ5NTE4NmQyZTM4OTU2MTBjZDJmMyIsInNoYTM4NCI6IjY0NmY0YWM0ZDA2YWJkZmE2MDAwN2VjZWNiOWNjOTk4ZmJkOTBiYzYwMmY3NTk2M2RhZDUzMGMzNGE5ZGE1YzY4NjhlMGIwMDJkZDNlMTM4ZjhmMjA2ODcyNzFkMDVjMSIsInNoYTUxMiI6ImYzZTU4NTA0YzYyOGUwYjViNzBhOTYxYThmODA1MDA1NjQ1M2E5NDlmNTgzNDhiYTNhZTVlMjdkNDRhNGJkMjc5ZjA3MmU1OGQ5YjEyOGE1NDc1MTU2ZmM3YzcxMGJkYjI3OWQ5OGFmN2EwYTI4Y2Y1ZDY2MmQxODY4Zjg3ZjI3IiwiYmxha2UyYiI6ImFhNjgyYmJjM2U1ZGRjNDZkNWUxN2VjMzRlNmEzZGY5ZjhiNWQyNzk0YTZkNmY0M2VjODMxZjhjOTU2OGYyY2RiOGE4YjAyNTE4MDA4YmY0Y2FhYTlhY2FhYjNkNzRmZmRiNGZlNDgwOTcwODU3OGJiZjNlNzJjYTc5ZDQwYzZmIiwiYmxha2UycyI6ImQ0ZGJlZTJkMmZlNDMwOGViYTkwMTY1MDdmMzI1ZmJiODZlMWQzNDQ0MjgzNzRlMjAwNjNiNWQ1MzkzZTExNjMiLCJzaGEzXzIyNCI6ImE1ZTM5NWZlNGRlYjIyY2JhNjgwMWFiZTliZjljMjM2YmMzYjkwZDdiN2ZjMTRhZDhjZjQ0NzBlIiwic2hhM18yNTYiOiIwOWYwZGVjODk0OWEzYmQzYzU3N2RjYzUyMTMwMGRiY2UwMjVjM2VjOTJkNzQ0MDJkNTE1ZDA4NTQwODg2NGY1Iiwic2hhM18zODQiOiJmMjEyNmM5NTcxODQ3NDZmNjYyMjE4MTRkMDZkZWQ3NDBhYWU3MDA4MTc0YjI0OTEzY2YwOTQzY2IwMTA5Y2QxNWI4YmMwOGY1YjUwMWYwYzhhOTY4MzUwYzgzY2I1ZWUiLCJzaGEzXzUxMiI6ImU1ZmEwMzIwMzk2YTJjMThjN2UxZjVlZmJiODYwYTU1M2NlMTlkMDQ0MWMxNWEwZTI1M2RiNjJkM2JmNjg0ZDI1OWIxYmQ4OTJkYTcyMDVjYTYyODQ2YzU0YWI1ODYxOTBmNDUxZDlmZmNkNDA5YmU5MzlhNWM1YWIyZDdkM2ZkIiwic2hha2VfMTI4IjoiNGI2MTllM2I4N2U1YTY4OTgxMjk0YzgzMmU0NzljZGI4MWFmODdlZTE4YzM1Zjc5ZjExODY5ZWEzNWUxN2I3MiIsInNoYWtlXzI1NiI6ImYzOWVkNmMxZmQ2NzVmMDg3ODAyYTc4ZTUwYWFkN2ZiYTZiM2QxNzhlZWYzMjRkMTI3ZTZjYmEwMGRjNzkwNTkxNjQ1Y2U1Y2NmMjhjYzVkNWRkODU1OWIzMDMxYTM3ZjE5NjhmYmFhNDQzMmI2ZWU0Yzg3ZWE2YTdkMmE2NWM2In0K",
"eyJibGFrZTNfbXVsdGkiOiJhNDRiZjJkMzVkZDI3OTZlZTI1NmY0MzVkODFhNTdhOGM0MjZhMzM5ZDc3NTVkMmNiMjdmMzU4ZjM0NTM4OWM2IiwiYmxha2UzX3NpbmdsZSI6ImE0NGJmMmQzNWRkMjc5NmVlMjU2ZjQzNWQ4MWE1N2E4YzQyNmEzMzlkNzc1NWQyY2IyN2YzNThmMzQ1Mzg5YzYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiOGU5OTMzMzEyZjg4NDY4MDg0ZmRiZWNjNDYyMTMxZTgiLCJzaGExIjoiNmI0MmZjZDFmMmQyNzUwYWNkY2JkMTUzMmQ4NjQ5YTM1YWI2NDYzNCIsInNoYTIyNCI6ImQ2Y2E2OTUxNzIzZjdjZjg0NzBjZWRjMmVhNjA2ODNmMWU4NDMzM2Q2NDM2MGIzOWIyMjZlZmQzIiwic2hhMjU2IjoiMDAxNGY5Yzg0YjcwMTFhMGJkNzliNzU0NGVjNzg4NDQzNWQ4ZGY0NmRjMDBiNDk0ZmFkYzA4NWQzNDM1NjI4MyIsInNoYTM4NCI6IjMxODg2OTYxODc4NWY3MWJlM2RlZjkyZDgyNzY2NjBhZGE0MGViYTdkMDk1M2Y0YTc5ODdlMThhNzFlNjBlY2EwY2YyM2YwMjVhMmQ4ZjUyMmNkZGY3MTcxODFhMTQxNSIsInNoYTUxMiI6IjdmZGQxN2NmOWU3ZTBhZDcwMzJjMDg1MTkyYWMxZmQ0ZmFhZjZkNWNlYzAzOTE5ZDk0MmZiZTIyNWNhNmIwZTg0NmQ4ZGI0ZjllYTQ5MjJlMTdhNTg4MTY4YzExMTM1NWZiZDQ1NTlmMmU5NDcwNjAwZWE1MzBhMDdiMzY0YWQwIiwiYmxha2UyYiI6IjI0ZjExZWI5M2VlN2YxOTI5NWZiZGU5MTczMmE0NGJkZGYxOWE1ZTQ4MWNmOWFhMjQ2M2UzNDllYjg0Mzc4ZDBkODFjNzY0YWQ1NTk1YjkxZjQzYzgxODcxNTRlYWU5NTZkY2ZjZTlkMWU2MTZjNTFkZThhZDZjZTBhODcyY2Q0IiwiYmxha2UycyI6IjVkZTUwZDUwMGYwYTBmOGRlMTEwOGE2ZmFkZGM4ODNlMTA3NmQ3MThiNmQxN2E4ZDVkMjgzZDdiNGYzZDU2OGEiLCJzaGEzXzIyNCI6IjFhNTA0OGNlYWZiYjg2ZDc4ZmNiNTI0ZTViYTc4NWQ2ZmY5NzY1ZTNlMzdhZWRjZmYxZGVjNGJhIiwic2hhM18yNTYiOiI0YjA0YjE1NTRmMzRkYTlmMjBmZDczM2IzNDg4NjE0ZWNhM2IwOWU1OTJjOGJlMmM0NjA1NjYyMWU0MjJmZDllIiwic2hhM18zODQiOiI1NjMwYjM2OGQ4MGM1YmM5MTgzM2VmNWM2YWUzOTJhNDE4NTNjYmM2MWJiNTI4ZDE4YWM1OWFjZGZiZWU1YThkMWMyZDE4MTM1ZGI2ZWQ2OTJlODFkZThmYTM3MzkxN2MiLCJzaGEzXzUxMiI6IjA2ODg4MGE1MmNiNDkzODYwZDhjOTVhOTFhZGFmZTYwZGYxODc2ZDhjYjFhNmI3NTU2ZjJjM2Y1NjFmMGYwZjMyZjZhYTA1YmVmN2FhYjQ5OWEwNTM0Zjk0Njc4MDEzODlmNDc0ODFiNzcxMjdjMDFiOGFhOTY4NGJhZGUzYmY2Iiwic2hha2VfMTI4IjoiODlmYTdjNDcwNGI4NGZkMWQ1M2E0MTBlN2ZjMzU3NWRhNmUxMGU1YzkzMjM1NWYyZWEyMWM4NDVhZDBlM2UxOCIsInNoYWtlXzI1NiI6IjE4NGNlMWY2NjdmYmIyODA5NWJhZmVkZTQzNTUzZjhkYzBhNGY1MDQwYWJlMjcxMzkzMzcwNDEyZWFiZTg0ZGJhNjI0Y2ZiZWE4YzUxZDU2YzkwMTM2Mjg2ODgyZmQ0Y2E3MzA3NzZjNWUzODFlYzI5MWYxYTczOTE1MDkyMTFmIn0K",
"eyJibGFrZTNfbXVsdGkiOiJhYjA2YjNmMDliNTExOTAzMTMzMzY5NDE2MTc4ZDk2ZjlkYTc3ZGEwOTgyNDJmN2VlMTVjNTNhNTRkMDZhNWVmIiwiYmxha2UzX3NpbmdsZSI6ImFiMDZiM2YwOWI1MTE5MDMxMzMzNjk0MTYxNzhkOTZmOWRhNzdkYTA5ODI0MmY3ZWUxNWM1M2E1NGQwNmE1ZWYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZWY0MjcxYjU3NTQwMjU4NGQ2OTI5ZWJkMGI3Nzk5NzYiLCJzaGExIjoiMzgzNzliYWQzZjZiZjc4MmM4OTgzOGY3YWVkMzRkNDNkMzNlYWM2MSIsInNoYTIyNCI6ImQ5ZDNiMjJkYmZlY2M1NTdlODAzNjg5M2M3ZWE0N2I0NTQzYzM2NzZhMDk4NzMxMzRhNjQ0OWEwIiwic2hhMjU2IjoiMjYxZGI3NmJlMGYxMzdlZWJkYmI5OGRlYWM0ZjcyMDdiOGUxMjdiY2MyZmMwODI5OGVjZDczYjQ3MjYxNjQ1NiIsInNoYTM4NCI6IjMzMjkwYWQxYjlhMmRkYmU0ODY3MWZiMTIxNDdiZWJhNjI4MjA1MDcwY2VkNjNiZTFmNGU5YWRhMjgwYWU2ZjZjNDkzYTY2MDllMGQ2YTIzMWU2ODU5ZmIyNGZhM2FjMCIsInNoYTUxMiI6IjAzMDZhMWI1NmNiYTdjNjJiNTNmNTk4MTAwMTQ3MDQ5ODBhNGRmZTdjZjQ5NTU4ZmMyMmQxZDczZDc5NzJmZTllODk2ZWRjMmEyYTQxYWVjNjRjZjkwZGUwYjI1NGM0MDBlZTU1YzcwZjk3OGVlMzk5NmM2YzhkNTBjYTI4YTdiIiwiYmxha2UyYiI6IjY1MDZhMDg1YWQ5MGZkZjk2NGJmMGE5NTFkZmVkMTllZTc0NGVjY2EyODQzZjQzYTI5NmFjZDM0M2RiODhhMDNlNTlkNmFmMGM1YWJkNTEzMzc4MTQ5Yjg3OTExMTVmODRmMDIyZWM1M2JmNGFjNDZhZDczNWIwMmJlYTM0MDk5IiwiYmxha2UycyI6IjdlZDQ3ZWQxOTg3MTk0YWFmNGIwMjQ3MWFkNTMyMmY3NTE3ZjI0OTcwMDc2Y2NmNDkzMWI0MzYxMDU1NzBlNDAiLCJzaGEzXzIyNCI6Ijk2MGM4MDExOTlhMGUzYWExNjdiNmU2MWVkMzE2ZDUzMDM2Yjk4M2UyOThkNWI5MjZmMDc3NDlhIiwic2hhM18yNTYiOiIzYzdmYWE1ZDE3Zjk2MGYxOTI2ZjNlNGIyZjc1ZjdiOWIyZDQ4NGFhNmEwM2ViOWNlMTI4NmM2OTE2YWEyM2RlIiwic2hhM18zODQiOiI5Y2Y0NDA1NWFjYzFlYjZmMDY1YjRjODcxYTYzNTM1MGE1ZjY0ODQwM2YwYTU0MWEzYzZhNjI3N2ViZjZmYTNjYmM1YmJiNjQwMDE4OGFlMWIxMTI2OGZmMDJiMzYzZDUiLCJzaGEzXzUxMiI6ImEyZDk3ZDRlYjYxM2UwZDViYTc2OTk2MzE2MzcxOGEwNDIxZDkxNTNiNjllYjM5MDRmZjI4ODRhZDdjNGJiYmIwNGY2Nzc1OTA1YmQxNGI2NTJmZTQ1Njg0YmI5MTQ3ZjBkYWViZjAxZjIzY2MzZDhkMjIzMTE0MGUzNjI4NTE5Iiwic2hha2VfMTI4IjoiNjkwMWMwYjg1MTg5ZTkyNTJiODI3MTc5NjE2MjRlMTM0MDQ1ZjlkMmI5MzM0MzVkM2Y0OThiZWIyN2Q3N2JiNSIsInNoYWtlXzI1NiI6ImIwMjA4ZTFkNDVjZWI0ODdiZDUwNzk3MWJiNWI3MjdjN2UyYmE3ZDliNWM2ZTEyYWE5YTNhOTY5YzcyNDRjODIwZDcyNDY1ODhlZWU3Yjk4ZWM1NzhjZWIxNjc3OTkxODljMWRkMmZkMmZmYWM4MWExZDAzZDFiNjMxOGRkMjBiIn0K",
]

View File

@@ -31,12 +31,13 @@ from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
class InvalidModelConfigException(Exception):
@@ -115,7 +116,7 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
Default = "" # model files without "fp16" or other qualifier - empty str
Default = "" # model files without "fp16" or other qualifier
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"
@@ -448,4 +449,6 @@ class ModelConfigFactory(object):
model.key = key
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
model.converted_at = timestamp
if model:
validate_hash(model.hash)
return model # type: ignore

View File

@@ -7,7 +7,7 @@ from importlib import import_module
from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import LoadedModel, ModelLoaderBase
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@@ -19,6 +19,7 @@ for module in loaders:
__all__ = [
"LoadedModel",
"LoadedModelWithoutConfig",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import safe_filename
from .convert_cache_base import ModelConvertCacheBase
@@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
key = safe_filename(self._cache_path, key)
return self._cache_path / key
def make_room(self, size: float) -> None:

View File

@@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Any, Optional
from typing import Any, Dict, Generator, Optional, Tuple
import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import (
@@ -20,10 +23,44 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
@dataclass
class LoadedModel:
"""Context manager object that mediates transfer from RAM<->VRAM."""
class LoadedModelWithoutConfig:
"""
Context manager object that mediates transfer from RAM<->VRAM.
This is a context manager object that has two distinct APIs:
1. Older API (deprecated):
Use the LoadedModel object directly as a context manager.
It will move the model into VRAM (on CUDA devices), and
return the model in a form suitable for passing to torch.
Example:
```
loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model as vae:
image = vae.decode(latents)[0]
```
2. Newer API (recommended):
Call the LoadedModel's `model_on_device()` method in a
context. It returns a tuple consisting of a copy of
the model's state dict in CPU RAM followed by a copy
of the model in VRAM. The state dict is provided to allow
LoRAs and other model patchers to return the model to
its unpatched state without expensive copy and restore
operations.
Example:
```
loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
image = vae.decode(latents)[0]
```
The state_dict should be treated as a read-only object and
never modified. Also be aware that some loadable models do
not have a state_dict, in which case this value will be None.
"""
config: AnyModelConfig
_locker: ModelLockerBase
def __enter__(self) -> AnyModel:
@@ -35,12 +72,29 @@ class LoadedModel:
"""Context exit."""
self._locker.unlock()
@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
locked_model = self._locker.lock()
try:
state_dict = self._locker.get_state_dict()
yield (state_dict, locked_model)
finally:
self._locker.unlock()
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self._locker.model
@dataclass
class LoadedModel(LoadedModelWithoutConfig):
"""Context manager object that mediates transfer from RAM<->VRAM."""
config: Optional[AnyModelConfig] = None
# TODO(MM2):
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
# know about. I think the problem may be related to this class being an ABC.
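Tying this together with the LoRA changes later in this diff: the state dict yielded by `model_on_device()` is what the patcher accepts as `model_state_dict`, which lets it restore the model without an expensive copy. A hedged sketch, assuming `loaded_model` (for a UNet) and `loras` were obtained from the model manager and the invocation respectively:
```
# Sketch only: `loaded_model`, `loras`, and ModelPatcher come from elsewhere in the app.
with loaded_model.model_on_device() as (state_dict, unet):
    with ModelPatcher.apply_lora_unet(unet, loras, model_state_dict=state_dict):
        ...  # run denoising against the patched UNet
```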

View File

@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
@@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
except IndexError:
pass
cache_path: Path = self._convert_cache.cache_path(config.key)
cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
@@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
config.key,
submodel_type=submodel_type,
model=loaded_model,
size=calc_model_size_by_data(loaded_model),
)
return self._ram_cache.get(
@@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
)
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:

View File

@@ -30,6 +30,11 @@ class ModelLockerBase(ABC):
"""Unlock the contained model, and remove it from VRAM."""
pass
@abstractmethod
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
pass
@property
@abstractmethod
def model(self) -> AnyModel:
@@ -42,10 +47,31 @@ T = TypeVar("T")
@dataclass
class CacheRecord(Generic[T]):
"""Elements of the cache."""
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It is used
as a template for creating a copy in VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
key: str
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0
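The copy-and-inject mechanism described in this docstring is implemented in `ModelCache.move_model_to_device()` further below; in miniature, and using a toy module purely for illustration, it amounts to:
```
# Toy illustration of the mechanism, not the cache's actual code path.
import torch

model = torch.nn.Linear(4, 4)
ram_state_dict = model.state_dict()  # the read-only template held in the CacheRecord

if torch.cuda.is_available():
    # Load: copy each tensor into VRAM and inject the copies into the model.
    vram_state_dict = {k: v.to("cuda", copy=True) for k, v in ram_state_dict.items()}
    model.load_state_dict(vram_state_dict, assign=True)

    # Unload: drop the VRAM copies and reinject the RAM template.
    model.load_state_dict(ram_state_dict, assign=True)
```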
@@ -143,7 +169,6 @@ class ModelCacheBase(ABC, Generic[T]):
self,
key: str,
model: T,
size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""

View File

@@ -20,7 +20,6 @@ context. Use like this:
import gc
import math
import sys
import time
from contextlib import suppress
from logging import Logger
@@ -30,6 +29,7 @@ import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -154,15 +154,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
key: str,
model: AnyModel,
size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
cache_record = CacheRecord(key, model, size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -251,23 +253,42 @@ class ModelCache(ModelCacheBase[AnyModel]):
May raise a torch.cuda.OutOfMemoryError
"""
# These attributes are not in the base ModelMixin class but in various derived classes.
# Some models don't have these attributes, in which case they run in RAM/CPU.
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
return
source_device = cache_entry.model.device
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
cache_entry.model.to(target_device)
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=True)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
@@ -347,43 +368,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manually clear local variable references of just-finished function calls
# for some reason python doesn't want to collect it even by gc.collect() immediately
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
f" refs: {refs}"
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
)
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)

View File

@@ -2,6 +2,8 @@
Base class and implementation of a class that moves models in and out of VRAM.
"""
from typing import Dict, Optional
import torch
from invokeai.backend.model_manager import AnyModel
@@ -27,20 +29,18 @@ class ModelLocker(ModelLockerBase):
"""Return the model without moving it around."""
return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
@@ -55,10 +55,7 @@ class ModelLocker(ModelLockerBase):
def unlock(self) -> None:
"""Call upon exit from context."""
if not hasattr(self.model, "to"):
return
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.offload_unlocked_models(0)
self._cache.print_cuda_stats()

View File

@@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config.get("_class_name", None)
if class_name:
if class_name := config.get("_class_name"):
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model":
class_name = config.get("architectures")
assert class_name is not None
elif class_name := config.get("architectures"):
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
if not class_name:
else:
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e

View File

@@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(id, s.rfilename, revision=variant),
url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
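A quick, hedged illustration of the `variant or "main"` fallback above, calling `huggingface_hub.hf_hub_url` directly (repo and file names are examples only):

```
from huggingface_hub import hf_hub_url

variant = None  # no variant recorded for this file
url = hf_hub_url("stabilityai/sd-vae-ft-mse", "config.json", revision=variant or "main")
# Expected to resolve against the repo's main branch, e.g.
# https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json
print(url)
```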

View File

@@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel):
url: AnyHttpUrl = Field(description="The url to download this model file")
path: Path = Field(description="The path to the file, relative to the model root")
size: int = Field(description="The size of this file, in bytes")
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
def __hash__(self) -> int:
return hash(str(self))
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""

View File

@@ -10,7 +10,7 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.util.util import SilenceWarnings
from invokeai.backend.util.silence_warnings import SilenceWarnings
from .config import (
AnyModelConfig,

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -66,8 +66,14 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
model_state_dict=model_state_dict,
):
yield
@classmethod
@@ -76,28 +82,9 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
yield
@classmethod
@@ -107,7 +94,16 @@ class ModelPatcher:
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> None:
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
original_weights = {}
try:
with torch.no_grad():
@@ -133,19 +129,22 @@ class ModelPatcher:
dtype = module.weight.dtype
if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
layer.to(device=device, non_blocking=True)
layer.to(dtype=torch.float32, non_blocking=True)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu"))
layer.to(device=torch.device("cpu"), non_blocking=True)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
@@ -154,7 +153,7 @@ class ModelPatcher:
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype)
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
yield # wait for context manager exit
@@ -162,7 +161,7 @@ class ModelPatcher:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
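A self-contained toy (not the real ModelPatcher) showing the same patch/restore pattern as above on a single Linear layer: keep a CPU copy of the original weight, add a low-rank delta in place, then copy the original back when finished. The rank/alpha/weight values are arbitrary.

```
import torch

layer = torch.nn.Linear(4, 4, bias=False)
original_weight = layer.weight.detach().to(device="cpu", copy=True)

with torch.no_grad():
    # Hypothetical LoRA delta: (up @ down) scaled by lora_weight * (alpha / rank).
    rank, alpha, lora_weight = 2, 2.0, 0.8
    down = torch.randn(rank, 4)
    up = torch.randn(4, rank)
    delta = (up @ down) * (lora_weight * (alpha / rank))
    layer.weight += delta.to(dtype=layer.weight.dtype)

# ... run inference with the patched layer here ...

with torch.no_grad():
    layer.weight.copy_(original_weight)  # unpatch, mirroring the restore loop above
```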
@classmethod
@contextmanager

View File

@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import onnx
import torch
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
@@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
# return self.io_binding.copy_outputs_to_cpu()
return self.session.run(None, inputs)
# compatibility with the RawModel ABC
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass
# compatibility with diffusers load code
@classmethod
def from_pretrained(

View File

@@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""
from abc import ABC, abstractmethod
from typing import Optional
class RawModel:
"""Base class for 'Raw' model wrappers."""
import torch
class RawModel(ABC):
"""Abstract base class for 'Raw' model wrappers."""
@abstractmethod
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass
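A minimal, hypothetical example of what the abstract `to()` contract above asks of a concrete wrapper (the class here is illustrative and does not exist in the codebase):

```
from typing import Optional

import torch


class InMemoryEmbedding:
    """Toy RawModel-style wrapper around a single tensor."""

    def __init__(self, embedding: torch.Tensor) -> None:
        self.embedding = embedding

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        self.embedding = self.embedding.to(device=device, dtype=dtype, non_blocking=non_blocking)


wrapper = InMemoryEmbedding(torch.zeros(2, 768))
wrapper.to(dtype=torch.float16)
print(wrapper.embedding.dtype)  # torch.float16
```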

View File

@@ -10,12 +10,11 @@ import PIL.Image
import psutil
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -26,6 +25,7 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@dataclass
@@ -38,56 +38,18 @@ class PipelineIntermediateState:
predicted_original: Optional[torch.Tensor] = None
@dataclass
class AddsMaskLatents:
"""Add the channels required for inpainting model input.
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
and the latent encoding of the base image.
This class assumes the same mask and base image should apply to all items in the batch.
"""
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
mask: torch.Tensor
initial_image_latents: torch.Tensor
def __call__(
self,
latents: torch.Tensor,
t: torch.Tensor,
text_embeddings: torch.Tensor,
**kwargs,
) -> torch.Tensor:
model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings, **kwargs)
def add_mask_channels(self, latents):
batch_size = latents.size(0)
# duplicate mask and latents for each batch
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
# add mask and image as additional channels
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
return model_input
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
return isinstance(b, torch.Tensor) and (a.size() == b.size())
@dataclass
class AddsMaskGuidance:
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
mask: torch.Tensor
mask_latents: torch.Tensor
scheduler: SchedulerMixin
noise: torch.Tensor
gradient_mask: bool
is_gradient_mask: bool
def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return self.apply_mask(latents, t)
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
if t.dim() == 0:
@@ -100,7 +62,7 @@ class AddsMaskGuidance:
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
if self.gradient_mask:
if self.is_gradient_mask:
threshold = (t.item()) / self.scheduler.config.num_train_timesteps
mask_bool = mask > threshold  # I don't know when mask got inverted, but it did
masked_input = torch.where(mask_bool, latents, mask_latents)
@@ -200,7 +162,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
control_model: ControlNetModel = None,
):
super().__init__(
vae=vae,
@@ -214,8 +175,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.control_model = control_model
self.use_ip_adapter = False
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
@@ -280,116 +239,131 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
raise Exception("Should not be called")
def add_inpainting_channels_to_latents(
self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
):
"""Given a `latents` tensor, adds the mask and image latents channels required for inpainting.
Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
input of shape (N, 9, H, W). The 9 channels are defined as follows:
- Channel 0-3: The latents being denoised.
- Channel 4: The mask indicating which parts of the image are being inpainted.
- Channel 5-8: The latent representation of the masked reference image being inpainted.
This function assumes that the same mask and base image should apply to all items in the batch.
"""
# Validate assumptions about input tensor shapes.
batch_size, latent_channels, latent_height, latent_width = latents.shape
assert latent_channels == 4
assert masked_ref_image_latents.shape == (1, 4, latent_height, latent_width)
assert inpainting_mask.shape == (1, 1, latent_height, latent_width)
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
# Concatenate along the channel dimension.
return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
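A standalone shape check (toy tensors only) of the 9-channel layout described in the docstring above: 4 latent channels, 1 mask channel, and 4 channels of masked reference-image latents, with the mask and reference latents broadcast across the batch.

```
import torch

batch_size, latent_height, latent_width = 2, 64, 64
latents = torch.zeros(batch_size, 4, latent_height, latent_width)
inpainting_mask = torch.ones(1, 1, latent_height, latent_width)
masked_ref_image_latents = torch.zeros(1, 4, latent_height, latent_width)

model_input = torch.cat(
    [
        latents,
        inpainting_mask.expand(batch_size, -1, -1, -1),
        masked_ref_image_latents.expand(batch_size, -1, -1, -1),
    ],
    dim=1,
)
print(model_input.shape)  # torch.Size([2, 9, 64, 64])
```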
def latents_from_embeddings(
self,
latents: torch.Tensor,
num_inference_steps: int,
scheduler_step_kwargs: dict[str, Any],
conditioning_data: TextConditioningData,
*,
noise: Optional[torch.Tensor],
seed: int,
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
callback: Callable[[PipelineIntermediateState], None],
control_data: list[ControlNetData] | None = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False,
seed: int,
is_gradient_mask: bool = False,
) -> torch.Tensor:
if init_timestep.shape[0] == 0:
return latents
"""Denoise the latents.
if additional_guidance is None:
additional_guidance = []
Args:
latents: The latent-space image to denoise.
- If we are inpainting, this is the initial latent image before noise has been added.
- If we are generating a new image, this should be initialized to zeros.
- In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
conditioning_data: Text conditioning data.
noise: Noise used for two purposes:
1. Used by the scheduler to noise the initial `latents` before denoising.
2. Used to noise the `masked_latents` when inpainting.
`noise` should be None if the `latents` tensor has already been noised.
seed: The seed used to generate the noise for the denoising process.
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
same noise used earlier in the pipeline. This should really be handled in a clearer way.
timesteps: The timestep schedule for the denoising process.
init_timestep: The first timestep in the schedule.
TODO(ryand): I'm pretty sure this should always be the same as timesteps[0:1]. Confirm that that is the
case, and remove this duplicate param.
callback: A callback function that is called to report progress during the denoising process.
control_data: ControlNet data.
ip_adapter_data: IP-Adapter data.
t2i_adapter_data: T2I-Adapter data.
mask: A mask indicating which parts of the image are being inpainted. The presence of mask is used to
determine whether we are inpainting or not. `mask` should have the same spatial dimensions as the
`latents` tensor.
TODO(ryand): Check and document the expected dtype, range, and values used to represent
foreground/background.
masked_latents: A latent-space representation of a masked inpainting reference image. This tensor is only
used if an *inpainting* model is being used i.e. this tensor is not used when inpainting with a standard
SD UNet model.
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
"""
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where denoising_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
return latents
orig_latents = latents.clone()
batch_size = latents.shape[0]
batched_t = init_timestep.expand(batch_size)
batched_init_timestep = init_timestep.expand(batch_size)
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
if noise is not None:
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
if mask is not None:
if is_inpainting_model(self.unet):
if masked_latents is None:
raise Exception("Source image required for inpaint mask when inpaint model used!")
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, mask, masked_latents
)
else:
# if no noise provided, noisify unmasked area based on seed
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
try:
latents = self.generate_latents_from_embeddings(
latents,
timesteps,
conditioning_data,
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=callback,
)
finally:
self.invokeai_diffuser.model_forward_callback = self._unet_forward
# restore unmasked part after the last step is completed
# in-process masking happens before each step
if mask is not None:
if gradient_mask:
latents = torch.where(mask > 0, latents, orig_latents)
else:
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
return latents
def generate_latents_from_embeddings(
self,
latents: torch.Tensor,
timesteps,
conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any],
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
) -> torch.Tensor:
self._adjust_memory_efficient_attention(latents)
if additional_guidance is None:
additional_guidance = []
batch_size = latents.shape[0]
# Handle mask guidance (a.k.a. inpainting).
mask_guidance: AddsMaskGuidance | None = None
if mask is not None and not is_inpainting_model(self.unet):
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
# apply mask guidance to the latents.
if timesteps.shape[0] == 0:
return latents
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
# We still need noise for inpainting, so we generate it from the seed here.
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
mask_guidance = AddsMaskGuidance(
mask=mask,
mask_latents=orig_latents,
scheduler=self.scheduler,
noise=noise,
is_gradient_mask=is_gradient_mask,
)
use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting:
@@ -402,28 +376,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
with attn_ctx:
if callback is not None:
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
step_output = self.step(
batched_t,
latents,
conditioning_data,
t=batched_t,
latents=latents,
conditioning_data=conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
mask_guidance=mask_guidance,
mask=mask,
masked_latents=masked_latents,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
@@ -431,19 +405,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = step_output.prev_sample
predicted_original = getattr(step_output, "pred_original_sample", None)
if callback is not None:
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
return latents
# restore unmasked part after the last step is completed
# in-process masking happens before each step
if mask is not None:
if is_gradient_mask:
latents = torch.where(mask > 0, latents, orig_latents)
else:
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
return latents
@torch.inference_mode()
def step(
@@ -454,19 +437,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
mask_guidance: AddsMaskGuidance | None,
mask: torch.Tensor | None,
masked_latents: torch.Tensor | None,
control_data: list[ControlNetData] | None = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
if additional_guidance is None:
additional_guidance = []
# one day we will expand this extension point, but for now it just does denoise masking
for guidance in additional_guidance:
latents = guidance(latents, timestep)
# Handle masked image-to-image (a.k.a. inpainting).
if mask_guidance is not None:
# NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
latents = mask_guidance(latents, timestep)
# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
@@ -514,6 +498,31 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_intrablock_additional_residuals = accum_adapter_state
# Handle inpainting models.
if is_inpainting_model(self.unet):
# NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
# self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
# latents.
if mask is not None:
if masked_latents is None:
raise ValueError("Source image required for inpaint mask when inpaint model used!")
latent_model_input = self.add_inpainting_channels_to_latents(
latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
)
else:
# We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
# We generate a global mask and empty original image so that we can still generate in this
# configuration.
# TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
# to do this.
# TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
# mask and original image once rather than on every denoising step.
latent_model_input = self.add_inpainting_channels_to_latents(
latents=latent_model_input,
masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
)
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input,
timestep=t,  # TODO: debug how batched and non-batched timesteps are handled
@@ -542,17 +551,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
for guidance in additional_guidance:
# apply the mask to any "denoised" or "pred_original_sample" fields
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
# again.
if mask_guidance is not None:
# Apply the mask to any "denoised" or "pred_original_sample" fields.
if hasattr(step_output, "denoised"):
step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
elif hasattr(step_output, "pred_original_sample"):
step_output.pred_original_sample = guidance(
step_output.pred_original_sample = mask_guidance(
step_output.pred_original_sample, self.scheduler.timesteps[-1]
)
else:
step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
return step_output
@@ -575,17 +585,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
**kwargs,
):
"""predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4:
# Pad out normal non-inpainting inputs for an inpainting model.
# FIXME: There are too many layers of functions and we have too many different ways of
# overriding things! This should get handled in a way more consistent with the other
# use of AddsMaskLatents.
latents = AddsMaskLatents(
self._unet_forward,
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
).add_mask_channels(latents)
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(
latents,

View File

@@ -0,0 +1,242 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Any, Callable, Optional
import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.tiles.utils import TBLR
# The maximum number of regions with compatible sizes that will be batched together.
# Larger batch sizes improve speed, but require more device memory.
MAX_REGION_BATCH_SIZE = 4
@dataclass
class MultiDiffusionRegionConditioning:
# Region coords in latent space.
region: TBLR
text_conditioning_data: TextConditioningData
control_data: list[ControlNetData]
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
def _split_into_region_batches(
self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]
) -> list[list[MultiDiffusionRegionConditioning]]:
# Group the regions by shape. Only regions with the same shape can be batched together.
conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {}
for region_conditioning in multi_diffusion_conditioning:
shape_hw = (
region_conditioning.region.bottom - region_conditioning.region.top,
region_conditioning.region.right - region_conditioning.region.left,
)
# In Python, a tuple of hashable objects is itself hashable, so it can be used as a dict key.
if shape_hw not in conditioning_by_shape:
conditioning_by_shape[shape_hw] = []
conditioning_by_shape[shape_hw].append(region_conditioning)
# Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint.
region_conditioning_batches = []
for region_conditioning_batch in conditioning_by_shape.values():
for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE):
region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE])
return region_conditioning_batches
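A standalone, simplified walkthrough of the grouping above (region shapes are made up): regions are bucketed by their (height, width) and each bucket is then chunked to at most MAX_REGION_BATCH_SIZE regions per batch.

```
MAX_REGION_BATCH_SIZE = 4
region_shapes = [(96, 96)] * 6 + [(96, 64)] * 2  # (height, width) of each region, in latent space

by_shape: dict[tuple[int, int], list[tuple[int, int]]] = {}
for shape in region_shapes:
    by_shape.setdefault(shape, []).append(shape)

batches = []
for same_shape_regions in by_shape.values():
    for i in range(0, len(same_shape_regions), MAX_REGION_BATCH_SIZE):
        batches.append(same_shape_regions[i : i + MAX_REGION_BATCH_SIZE])

print([len(batch) for batch in batches])  # [4, 2, 2]: two batches of (96, 96) regions, one of (96, 64)
```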
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
"""Check the input conditioning and confirm that regional prompting is not used."""
for region_conditioning in multi_diffusion_conditioning:
if (
region_conditioning.text_conditioning_data.cond_regions is not None
or region_conditioning.text_conditioning_data.uncond_regions is not None
):
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
def multi_diffusion_denoise(
self,
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
latents: torch.Tensor,
scheduler_step_kwargs: dict[str, Any],
noise: Optional[torch.Tensor],
timesteps: torch.Tensor,
init_timestep: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None],
) -> torch.Tensor:
self._check_regional_prompting(multi_diffusion_conditioning)
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where denoising_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
return latents
batch_size, _, latent_height, latent_width = latents.shape
batched_init_timestep = init_timestep.expand(batch_size)
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
if noise is not None:
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
# cropping into regions.
self._adjust_memory_efficient_attention(latents)
# Populate a weighted mask that will be used to combine the results from each region after every step.
# For now, we assume that each region has the same weight (1.0).
region_weight_mask = torch.zeros(
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
)
for region_conditioning in multi_diffusion_conditioning:
region = region_conditioning.region
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
# Group the region conditioning into batches for faster processing.
# region_conditioning_batches[b][r] is the r'th region in the b'th batch.
region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning)
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
# separate scheduler state for each region batch.
region_batch_schedulers: list[SchedulerMixin] = [
copy.deepcopy(self.scheduler) for _ in region_conditioning_batches
]
callback(
PipelineIntermediateState(
step=-1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
latents=latents,
)
)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t = t.expand(batch_size)
merged_latents = torch.zeros_like(latents)
merged_pred_original: torch.Tensor | None = None
for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches):
# Switch to the scheduler for the region batch.
self.scheduler = region_batch_schedulers[region_batch_idx]
# TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1.
# Prepare the latents for the region batch.
batch_latents = torch.cat(
[
latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
for region_conditioning in region_conditioning_batch
],
)
# TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step()
# handle broadcasting properly?
# TODO(ryand): Resume here.
# Run the denoising step on the region.
step_output = self.step(
t=batched_t,
latents=batch_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
# Run a denoising step on the region.
# step_output = self._region_step(
# region_conditioning=region_conditioning,
# t=batched_t,
# latents=latents,
# step_index=i,
# total_step_count=len(timesteps),
# scheduler_step_kwargs=scheduler_step_kwargs,
# )
# Store the results from the region.
region = region_conditioning.region
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
if pred_orig_sample is not None:
# If one region has pred_original_sample, then we can assume that all regions will have it, because
# they all use the same scheduler.
if merged_pred_original is None:
merged_pred_original = torch.zeros_like(latents)
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
pred_orig_sample
)
# Normalize the merged results.
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
predicted_original = None
if merged_pred_original is not None:
predicted_original = torch.where(
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
)
callback(
PipelineIntermediateState(
step=i,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),
latents=latents,
predicted_original=predicted_original,
)
)
return latents
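A toy 1-D illustration (not the real pipeline code) of the merging above: each region accumulates its step output, the weight mask counts how many regions cover each position, and dividing by the mask averages the overlapping area.

```
import torch

merged = torch.zeros(8)
weight_mask = torch.zeros(8)

regions = [(0, 5), (3, 8)]  # two overlapping spans, each with weight 1.0
step_outputs = [torch.full((5,), 1.0), torch.full((5,), 3.0)]  # per-region results for one step

for (start, end), result in zip(regions, step_outputs):
    weight_mask[start:end] += 1.0
    merged[start:end] += result

merged = torch.where(weight_mask > 0, merged / weight_mask, merged)
print(merged)  # tensor([1., 1., 1., 2., 2., 3., 3., 3.]) -- the overlap (positions 3-4) is averaged
```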
@torch.inference_mode()
def _region_batch_step(
self,
region_conditioning: MultiDiffusionRegionConditioning,
t: torch.Tensor,
latents: torch.Tensor,
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
):
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
# Run the denoising step on the region.
return self.step(
t=t,
latents=region_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=step_index,
total_step_count=total_step_count,
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)

View File

@@ -1,7 +1,7 @@
"""Textual Inversion wrapper class."""
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import torch
from compel.embeddings_provider import BaseTextualInversionManager
@@ -65,36 +65,65 @@ class TextualInversionModelRaw(RawModel):
return result
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if not torch.cuda.is_available():
return
for emb in [self.embedding, self.embedding_2]:
if emb is not None:
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.pad_tokens: dict[int, list[int]] = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
"""Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
mapping of tokens to token_ids:
```
<ti_dog>: 49408
<ti_dog-!pad-1>: 49409
<ti_dog-!pad-2>: 49410
<ti_dog-!pad-3>: 49411
```
`self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
"""
# Short circuit if there are no pad tokens to save a little time.
if len(self.pad_tokens) == 0:
return token_ids
# This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
# this assumption here.
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
# Expand any TI tokens to their corresponding pad tokens.
new_token_ids: list[int] = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
# Do not exceed the max model input size. The -2 here is compensating for
# compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
max_length = self.tokenizer.model_max_length - 2
if len(new_token_ids) > max_length:
# HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
# tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
# prompts.
new_token_ids = new_token_ids[0:max_length]
return new_token_ids
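A standalone walkthrough of the expansion loop above, assuming (hypothetically) that `pad_tokens` maps the `<ti_dog>` token id to its three trailing pad token ids:

```
pad_tokens = {49408: [49409, 49410, 49411]}  # hypothetical mapping for a 4-vector <ti_dog> TI

token_ids = [320, 49408, 525]  # e.g. "a <ti_dog> photo", already stripped of BOS/EOS
new_token_ids: list[int] = []
for token_id in token_ids:
    new_token_ids.append(token_id)
    if token_id in pad_tokens:
        new_token_ids.extend(pad_tokens[token_id])

print(new_token_ids)  # [320, 49408, 49409, 49410, 49411, 525]
```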

View File

@@ -1,29 +1,36 @@
"""Context class to silence transformers and diffusers warnings."""
import warnings
from typing import Any
from contextlib import ContextDecorator
from diffusers import logging as diffusers_logging
from diffusers.utils import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
"""Use in context to temporarily turn off warnings from transformers & diffusers modules.
# Inherit from ContextDecorator to allow using SilenceWarnings as both a context manager and a decorator.
class SilenceWarnings(ContextDecorator):
"""A context manager that disables warnings from transformers & diffusers modules while active.
As context manager:
```
with SilenceWarnings():
# do something
```
As decorator:
```
@SilenceWarnings()
def some_function():
# do something
```
"""
def __init__(self) -> None:
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self) -> None:
self._transformers_verbosity = transformers_logging.get_verbosity()
self._diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, *args: Any) -> None:
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
def __exit__(self, *args) -> None:
transformers_logging.set_verbosity(self._transformers_verbosity)
diffusers_logging.set_verbosity(self._diffusers_verbosity)
warnings.simplefilter("default")

View File

@@ -1,17 +1,43 @@
import base64
import io
import os
import warnings
import re
import unicodedata
from pathlib import Path
from diffusers import logging as diffusers_logging
from PIL import Image
from transformers import logging as transformers_logging
# Number of bytes in a gibibyte (2**30).
GIG = 1073741824
def slugify(value: str, allow_unicode: bool = False) -> str:
"""
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Replace slashes with underscores.
Convert to lowercase. Also strip leading and
trailing whitespace, dashes, and underscores.
Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
"""
value = str(value)
if allow_unicode:
value = unicodedata.normalize("NFKC", value)
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[/]", "_", value.lower())
value = re.sub(r"[^.\w\s-]", "", value.lower())
return re.sub(r"[-\s]+", "-", value).strip("-_")
def safe_filename(directory: Path, value: str) -> str:
"""Make a string safe to use as a filename."""
escaped_string = slugify(value)
max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256
return escaped_string[len(escaped_string) - max_name_length :]
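A quick, illustrative check of the behaviour of slugify() above (the input strings are made up for the example, and the calls assume the slugify() defined above is in scope):

```
print(slugify("SDXL Base 1.0"))          # sdxl-base-1.0
print(slugify("My Model/v1.0 (final)"))  # my-model_v1.0-final
```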
def directory_size(directory: Path) -> int:
"""
Return the aggregate size of all files in a directory (bytes).
@@ -51,21 +77,3 @@ class Chdir(object):
def __exit__(self, *args):
os.chdir(self.original)
class SilenceWarnings(object):
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
def __enter__(self):
"""Set verbosity to error."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
"""Restore logger verbosity to state before context was entered."""
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")

View File

@@ -1021,7 +1021,8 @@
"float": "Kommazahlen",
"enum": "Aufzählung",
"fullyContainNodes": "Vollständig ausgewählte Nodes auswählen",
"editMode": "Im Workflow-Editor bearbeiten"
"editMode": "Im Workflow-Editor bearbeiten",
"resetToDefaultValue": "Auf Standardwert zurücksetzen"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -2,6 +2,7 @@
"accessibility": {
"about": "About",
"createIssue": "Create Issue",
"submitSupportTicket": "Submit Support Ticket",
"invokeProgressBar": "Invoke progress bar",
"menu": "Menu",
"mode": "Mode",
@@ -146,7 +147,11 @@
"viewing": "Viewing",
"viewingDesc": "Review images in a large gallery view",
"editing": "Editing",
"editingDesc": "Edit on the Control Layers canvas"
"editingDesc": "Edit on the Control Layers canvas",
"comparing": "Comparing",
"comparingDesc": "Comparing two images",
"enabled": "Enabled",
"disabled": "Disabled"
},
"controlnet": {
"controlAdapter_one": "Control Adapter",
@@ -372,7 +377,23 @@
"bulkDownloadRequestFailed": "Problem Preparing Download",
"bulkDownloadFailed": "Download Failed",
"problemDeletingImages": "Problem Deleting Images",
"problemDeletingImagesDesc": "One or more images could not be deleted"
"problemDeletingImagesDesc": "One or more images could not be deleted",
"viewerImage": "Viewer Image",
"compareImage": "Compare Image",
"openInViewer": "Open in Viewer",
"selectForCompare": "Select for Compare",
"selectAnImageToCompare": "Select an Image to Compare",
"slider": "Slider",
"sideBySide": "Side-by-Side",
"hover": "Hover",
"swapImages": "Swap Images",
"compareOptions": "Comparison Options",
"stretchToFit": "Stretch to Fit",
"exitCompare": "Exit Compare",
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
},
"hotkeys": {
"searchHotkeys": "Search Hotkeys",
@@ -897,7 +918,10 @@
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time."
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.",
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default"
},
"parameters": {
"aspect": "Aspect",
@@ -1070,8 +1094,9 @@
},
"toast": {
"addedToBoard": "Added to board",
"baseModelChangedCleared_one": "Base model changed, cleared or disabled {{count}} incompatible submodel",
"baseModelChangedCleared_other": "Base model changed, cleared or disabled {{count}} incompatible submodels",
"baseModelChanged": "Base Model Changed",
"baseModelChangedCleared_one": "Cleared or disabled {{count}} incompatible submodel",
"baseModelChangedCleared_other": "Cleared or disabled {{count}} incompatible submodels",
"canceled": "Processing Canceled",
"canvasCopiedClipboard": "Canvas Copied to Clipboard",
"canvasDownloaded": "Canvas Downloaded",
@@ -1092,10 +1117,17 @@
"metadataLoadFailed": "Failed to load metadata",
"modelAddedSimple": "Model Added to Queue",
"modelImportCanceled": "Model Import Canceled",
"outOfMemoryError": "Out of Memory Error",
"outOfMemoryErrorDesc": "Your current generation settings exceed system capacity. Please adjust your settings and try again.",
"parameters": "Parameters",
"parameterNotSet": "{{parameter}} not set",
"parameterSet": "{{parameter}} set",
"parametersNotSet": "Parameters Not Set",
"parameterSet": "Parameter Recalled",
"parameterSetDesc": "Recalled {{parameter}}",
"parameterNotSet": "Parameter Not Recalled",
"parameterNotSetDesc": "Unable to recall {{parameter}}",
"parameterNotSetDescWithMessage": "Unable to recall {{parameter}}: {{message}}",
"parametersSet": "Parameters Recalled",
"parametersNotSet": "Parameters Not Recalled",
"errorCopied": "Error Copied",
"problemCopyingCanvas": "Problem Copying Canvas",
"problemCopyingCanvasDesc": "Unable to export base layer",
"problemCopyingImage": "Unable to Copy Image",
@@ -1115,11 +1147,13 @@
"sentToImageToImage": "Sent To Image To Image",
"sentToUnifiedCanvas": "Sent to Unified Canvas",
"serverError": "Server Error",
"sessionRef": "Session: {{sessionId}}",
"setAsCanvasInitialImage": "Set as canvas initial image",
"setCanvasInitialImage": "Set canvas initial image",
"setControlImage": "Set as control image",
"setInitialImage": "Set as initial image",
"setNodeField": "Set as node field",
"somethingWentWrong": "Something Went Wrong",
"uploadFailed": "Upload failed",
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"uploadInitialImage": "Upload Initial Image",
@@ -1559,7 +1593,6 @@
"controlLayers": "Control Layers",
"globalMaskOpacity": "Global Mask Opacity",
"autoNegative": "Auto Negative",
"toggleVisibility": "Toggle Layer Visibility",
"deletePrompt": "Delete Prompt",
"resetRegion": "Reset Region",
"debugLayers": "Debug Layers",

View File

@@ -6,7 +6,7 @@
"settingsLabel": "Ajustes",
"img2img": "Imagen a Imagen",
"unifiedCanvas": "Lienzo Unificado",
"nodes": "Editor del flujo de trabajo",
"nodes": "Flujos de trabajo",
"upload": "Subir imagen",
"load": "Cargar",
"statusDisconnected": "Desconectado",
@@ -14,7 +14,7 @@
"discordLabel": "Discord",
"back": "Atrás",
"loading": "Cargando",
"postprocessing": "Tratamiento posterior",
"postprocessing": "Postprocesado",
"txt2img": "De texto a imagen",
"accept": "Aceptar",
"cancel": "Cancelar",
@@ -42,7 +42,42 @@
"copy": "Copiar",
"beta": "Beta",
"on": "En",
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:"
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:",
"installed": "Instalado",
"green": "Verde",
"editor": "Editor",
"orderBy": "Ordenar por",
"file": "Archivo",
"goTo": "Ir a",
"imageFailedToLoad": "No se puede cargar la imagen",
"saveAs": "Guardar Como",
"somethingWentWrong": "Algo salió mal",
"nextPage": "Página Siguiente",
"selected": "Seleccionado",
"tab": "Tabulador",
"positivePrompt": "Prompt Positivo",
"negativePrompt": "Prompt Negativo",
"error": "Error",
"format": "formato",
"unknown": "Desconocido",
"input": "Entrada",
"nodeEditor": "Editor de nodos",
"template": "Plantilla",
"prevPage": "Página Anterior",
"red": "Rojo",
"alpha": "Transparencia",
"outputs": "Salidas",
"editing": "Editando",
"learnMore": "Aprende más",
"enabled": "Activado",
"disabled": "Desactivado",
"folder": "Carpeta",
"updated": "Actualizado",
"created": "Creado",
"save": "Guardar",
"unknownError": "Error Desconocido",
"blue": "Azul",
"viewingDesc": "Revisar imágenes en una vista de galería grande"
},
"gallery": {
"galleryImageSize": "Tamaño de la imagen",
@@ -382,7 +417,7 @@
"canvasMerged": "Lienzo consolidado",
"sentToImageToImage": "Enviar hacia Imagen a Imagen",
"sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
"parametersNotSet": "Parámetros no establecidos",
"parametersNotSet": "Parámetros no recuperados",
"metadataLoadFailed": "Error al cargar metadatos",
"serverError": "Error en el servidor",
"canceled": "Procesando la cancelación",
@@ -390,7 +425,8 @@
"uploadFailedInvalidUploadDesc": "Debe ser una sola imagen PNG o JPEG",
"parameterSet": "Conjunto de parámetros",
"parameterNotSet": "Parámetro no configurado",
"problemCopyingImage": "No se puede copiar la imagen"
"problemCopyingImage": "No se puede copiar la imagen",
"errorCopied": "Error al copiar"
},
"tooltip": {
"feature": {
@@ -466,7 +502,8 @@
"about": "Acerca de",
"createIssue": "Crear un problema",
"resetUI": "Interfaz de usuario $t(accessibility.reset)",
"mode": "Modo"
"mode": "Modo",
"submitSupportTicket": "Enviar Ticket de Soporte"
},
"nodes": {
"zoomInNodes": "Acercar",
@@ -542,5 +579,17 @@
"layers_one": "Capa",
"layers_many": "Capas",
"layers_other": "Capas"
},
"controlnet": {
"crop": "Cortar",
"delete": "Eliminar",
"depthAnythingDescription": "Generación de mapa de profundidad usando la técnica de Depth Anything",
"duplicate": "Duplicar",
"colorMapDescription": "Genera un mapa de color desde la imagen",
"depthMidasDescription": "Crea un mapa de profundidad con Midas",
"balanced": "Equilibrado",
"beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
"detectResolution": "Detectar resolución",
"beginEndStepPercentShort": "Inicio / Final %"
}
}

View File

@@ -45,7 +45,7 @@
"outputs": "Risultati",
"data": "Dati",
"somethingWentWrong": "Qualcosa è andato storto",
"copyError": "$t(gallery.copy) Errore",
"copyError": "Errore $t(gallery.copy)",
"input": "Ingresso",
"notInstalled": "Non $t(common.installed)",
"unknownError": "Errore sconosciuto",
@@ -85,7 +85,11 @@
"viewing": "Visualizza",
"viewingDesc": "Rivedi le immagini in un'ampia vista della galleria",
"editing": "Modifica",
"editingDesc": "Modifica nell'area Livelli di controllo"
"editingDesc": "Modifica nell'area Livelli di controllo",
"enabled": "Abilitato",
"disabled": "Disabilitato",
"comparingDesc": "Confronta due immagini",
"comparing": "Confronta"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -122,14 +126,30 @@
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
"bulkDownloadFailed": "Scaricamento fallito",
"alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine"
"alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine",
"openInViewer": "Apri nel visualizzatore",
"selectForCompare": "Seleziona per il confronto",
"selectAnImageToCompare": "Seleziona un'immagine da confrontare",
"slider": "Cursore",
"sideBySide": "Fianco a Fianco",
"compareImage": "Immagine di confronto",
"viewerImage": "Immagine visualizzata",
"hover": "Al passaggio del mouse",
"swapImages": "Scambia le immagini",
"compareOptions": "Opzioni di confronto",
"stretchToFit": "Scala per adattare",
"exitCompare": "Esci dal confronto",
"compareHelp1": "Tieni premuto <Kbd>Alt</Kbd> mentre fai clic su un'immagine della galleria o usi i tasti freccia per cambiare l'immagine di confronto.",
"compareHelp2": "Premi <Kbd>M</Kbd> per scorrere le modalità di confronto.",
"compareHelp3": "Premi <Kbd>C</Kbd> per scambiare le immagini confrontate.",
"compareHelp4": "Premi <Kbd>Z</Kbd> o <Kbd>Esc</Kbd> per uscire."
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
"appHotkeys": "Applicazione",
"generalHotkeys": "Generale",
"galleryHotkeys": "Galleria",
"unifiedCanvasHotkeys": "Tela Unificata",
"unifiedCanvasHotkeys": "Tela",
"invoke": {
"title": "Invoke",
"desc": "Genera un'immagine"
@@ -147,8 +167,8 @@
"desc": "Apre e chiude il pannello delle opzioni"
},
"pinOptions": {
"title": "Appunta le opzioni",
"desc": "Blocca il pannello delle opzioni"
"title": "Fissa le opzioni",
"desc": "Fissa il pannello delle opzioni"
},
"toggleGallery": {
"title": "Attiva/disattiva galleria",
@@ -332,14 +352,14 @@
"title": "Annulla e cancella"
},
"resetOptionsAndGallery": {
"title": "Ripristina Opzioni e Galleria",
"desc": "Reimposta le opzioni e i pannelli della galleria"
"title": "Ripristina le opzioni e la galleria",
"desc": "Reimposta i pannelli delle opzioni e della galleria"
},
"searchHotkeys": "Cerca tasti di scelta rapida",
"noHotkeysFound": "Nessun tasto di scelta rapida trovato",
"toggleOptionsAndGallery": {
"desc": "Apre e chiude le opzioni e i pannelli della galleria",
"title": "Attiva/disattiva le Opzioni e la Galleria"
"title": "Attiva/disattiva le opzioni e la galleria"
},
"clearSearch": "Cancella ricerca",
"remixImage": {
@@ -348,7 +368,7 @@
},
"toggleViewer": {
"title": "Attiva/disattiva il visualizzatore di immagini",
"desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente."
"desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
}
},
"modelManager": {
@@ -378,7 +398,7 @@
"convertToDiffusers": "Converti in Diffusori",
"convertToDiffusersHelpText2": "Questo processo sostituirà la voce in Gestione Modelli con la versione Diffusori dello stesso modello.",
"convertToDiffusersHelpText4": "Questo è un processo una tantum. Potrebbero essere necessari circa 30-60 secondi a seconda delle specifiche del tuo computer.",
"convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB di dimensioni.",
"convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB in dimensione.",
"convertToDiffusersHelpText6": "Vuoi convertire questo modello?",
"modelConverted": "Modello convertito",
"alpha": "Alpha",
@@ -524,7 +544,20 @@
"missingNodeTemplate": "Modello di nodo mancante",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} ingresso mancante",
"missingFieldTemplate": "Modello di campo mancante",
"imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata"
"imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata",
"layer": {
"initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
"t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
"controlAdapterNoModelSelected": "Nessun modello di adattatore di controllo selezionato",
"controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
"controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
"controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
"ipAdapterNoModelSelected": "Nessun adattatore IP selezionato",
"ipAdapterIncompatibleBaseModel": "Il modello base dell'adattatore IP non è compatibile",
"ipAdapterNoImageSelected": "Nessuna immagine dell'adattatore IP selezionata",
"rgNoPromptsOrIPAdapters": "Nessun prompt o adattatore IP",
"rgNoRegion": "Nessuna regione selezionata"
}
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@@ -593,25 +626,25 @@
"canvasMerged": "Tela unita",
"sentToImageToImage": "Inviato a Generazione da immagine",
"sentToUnifiedCanvas": "Inviato alla Tela",
"parametersNotSet": "Parametri non impostati",
"parametersNotSet": "Parametri non richiamati",
"metadataLoadFailed": "Impossibile caricare i metadati",
"serverError": "Errore del Server",
"connected": "Connesso al Server",
"connected": "Connesso al server",
"canceled": "Elaborazione annullata",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
"parameterSet": "{{parameter}} impostato",
"parameterNotSet": "{{parameter}} non impostato",
"parameterSet": "Parametro richiamato",
"parameterNotSet": "Parametro non richiamato",
"problemCopyingImage": "Impossibile copiare l'immagine",
"baseModelChangedCleared_one": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modello incompatibile",
"baseModelChangedCleared_many": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
"baseModelChangedCleared_other": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
"baseModelChangedCleared_one": "Cancellato o disabilitato {{count}} sottomodello incompatibile",
"baseModelChangedCleared_many": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"imageSavingFailed": "Salvataggio dell'immagine non riuscito",
"canvasSentControlnetAssets": "Tela inviata a ControlNet & Risorse",
"problemCopyingCanvasDesc": "Impossibile copiare la tela",
"loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
"canvasCopiedClipboard": "Tela copiata negli appunti",
"maskSavedAssets": "Maschera salvata nelle risorse",
"problemDownloadingCanvas": "Problema durante il download della tela",
"problemDownloadingCanvas": "Problema durante lo scarico della tela",
"problemMergingCanvas": "Problema nell'unione delle tele",
"imageUploaded": "Immagine caricata",
"addedToBoard": "Aggiunto alla bacheca",
@@ -645,7 +678,17 @@
"problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata",
"parameters": "Parametri"
"parameters": "Parametri",
"parameterSetDesc": "{{parameter}} richiamato",
"parameterNotSetDesc": "Impossibile richiamare {{parameter}}",
"parameterNotSetDescWithMessage": "Impossibile richiamare {{parameter}}: {{message}}",
"parametersSet": "Parametri richiamati",
"errorCopied": "Errore copiato",
"outOfMemoryError": "Errore di memoria esaurita",
"baseModelChanged": "Modello base modificato",
"sessionRef": "Sessione: {{sessionId}}",
"somethingWentWrong": "Qualcosa è andato storto",
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
},
"tooltip": {
"feature": {
@@ -661,7 +704,7 @@
"layer": "Livello",
"base": "Base",
"mask": "Maschera",
"maskingOptions": "Opzioni di mascheramento",
"maskingOptions": "Opzioni maschera",
"enableMask": "Abilita maschera",
"preserveMaskedArea": "Mantieni area mascherata",
"clearMask": "Cancella maschera (Shift+C)",
@@ -732,7 +775,8 @@
"mode": "Modalità",
"resetUI": "$t(accessibility.reset) l'Interfaccia Utente",
"createIssue": "Segnala un problema",
"about": "Informazioni"
"about": "Informazioni",
"submitSupportTicket": "Invia ticket di supporto"
},
"nodes": {
"zoomOutNodes": "Rimpicciolire",
@@ -777,7 +821,7 @@
"workflowNotes": "Note",
"versionUnknown": " Versione sconosciuta",
"unableToValidateWorkflow": "Impossibile convalidare il flusso di lavoro",
"updateApp": "Aggiorna App",
"updateApp": "Aggiorna Applicazione",
"unableToLoadWorkflow": "Impossibile caricare il flusso di lavoro",
"updateNode": "Aggiorna nodo",
"version": "Versione",
@@ -824,8 +868,8 @@
"unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi",
"addLinearView": "Aggiungi alla vista Lineare",
"unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro",
"collectionFieldType": "{{name}} Raccolta",
"collectionOrScalarFieldType": "{{name}} Raccolta|Scalare",
"collectionFieldType": "{{name}} (Raccolta)",
"collectionOrScalarFieldType": "{{name}} (Singola o Raccolta)",
"nodeVersion": "Versione Nodo",
"inputFieldTypeParseError": "Impossibile analizzare il tipo di campo di input {{node}}.{{field}} ({{message}})",
"unsupportedArrayItemType": "Tipo di elemento dell'array non supportato \"{{type}}\"",
@@ -863,11 +907,20 @@
"edit": "Modifica",
"graph": "Grafico",
"showEdgeLabelsHelp": "Mostra etichette sui collegamenti, che indicano i nodi collegati",
"showEdgeLabels": "Mostra le etichette del collegamento"
"showEdgeLabels": "Mostra le etichette del collegamento",
"cannotMixAndMatchCollectionItemTypes": "Impossibile combinare e abbinare i tipi di elementi della raccolta",
"noGraph": "Nessun grafico",
"missingNode": "Nodo di invocazione mancante",
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
"menuItemAutoAdd": "Aggiungi automaticamente a questa Bacheca",
"menuItemAutoAdd": "Aggiungi automaticamente a questa bacheca",
"cancel": "Annulla",
"addBoard": "Aggiungi Bacheca",
"bottomMessage": "L'eliminazione di questa bacheca e delle sue immagini ripristinerà tutte le funzionalità che le stanno attualmente utilizzando.",
@@ -879,7 +932,7 @@
"myBoard": "Bacheca",
"searchBoard": "Cerca bacheche ...",
"noMatching": "Nessuna bacheca corrispondente",
"selectBoard": "Seleziona una Bacheca",
"selectBoard": "Seleziona una bacheca",
"uncategorized": "Non categorizzato",
"downloadBoard": "Scarica la bacheca",
"deleteBoardOnly": "solo la Bacheca",
@@ -900,7 +953,7 @@
"control": "Controllo",
"crop": "Ritaglia",
"depthMidas": "Profondità (Midas)",
"detectResolution": "Rileva risoluzione",
"detectResolution": "Rileva la risoluzione",
"controlMode": "Modalità di controllo",
"cannyDescription": "Canny rilevamento bordi",
"depthZoe": "Profondità (Zoe)",
@@ -911,7 +964,7 @@
"showAdvanced": "Mostra opzioni Avanzate",
"bgth": "Soglia rimozione sfondo",
"importImageFromCanvas": "Importa immagine dalla Tela",
"lineartDescription": "Converte l'immagine in lineart",
"lineartDescription": "Converte l'immagine in linea",
"importMaskFromCanvas": "Importa maschera dalla Tela",
"hideAdvanced": "Nascondi opzioni avanzate",
"resetControlImage": "Reimposta immagine di controllo",
@@ -927,7 +980,7 @@
"pidiDescription": "Elaborazione immagini PIDI",
"fill": "Riempie",
"colorMapDescription": "Genera una mappa dei colori dall'immagine",
"lineartAnimeDescription": "Elaborazione lineart in stile anime",
"lineartAnimeDescription": "Elaborazione linea in stile anime",
"imageResolution": "Risoluzione dell'immagine",
"colorMap": "Colore",
"lowThreshold": "Soglia inferiore",
@@ -1034,7 +1087,16 @@
"graphFailedToQueue": "Impossibile mettere in coda il grafico",
"batchFieldValues": "Valori Campi Lotto",
"time": "Tempo",
"openQueue": "Apri coda"
"openQueue": "Apri coda",
"iterations_one": "Iterazione",
"iterations_many": "Iterazioni",
"iterations_other": "Iterazioni",
"prompts_one": "Prompt",
"prompts_many": "Prompt",
"prompts_other": "Prompt",
"generations_one": "Generazione",
"generations_many": "Generazioni",
"generations_other": "Generazioni"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -1563,7 +1625,6 @@
"brushSize": "Dimensioni del pennello",
"globalMaskOpacity": "Opacità globale della maschera",
"autoNegative": "Auto Negativo",
"toggleVisibility": "Attiva/disattiva la visibilità dei livelli",
"deletePrompt": "Cancella il prompt",
"debugLayers": "Debug dei Livelli",
"rectangle": "Rettangolo",


@@ -6,7 +6,7 @@
"settingsLabel": "Instellingen",
"img2img": "Afbeelding naar afbeelding",
"unifiedCanvas": "Centraal canvas",
"nodes": "Werkstroom-editor",
"nodes": "Werkstromen",
"upload": "Upload",
"load": "Laad",
"statusDisconnected": "Niet verbonden",
@@ -34,7 +34,60 @@
"controlNet": "ControlNet",
"imageFailedToLoad": "Kan afbeelding niet laden",
"learnMore": "Meer informatie",
"advanced": "Uitgebreid"
"advanced": "Uitgebreid",
"file": "Bestand",
"installed": "Geïnstalleerd",
"notInstalled": "Niet $t(common.installed)",
"simple": "Eenvoudig",
"somethingWentWrong": "Er ging iets mis",
"add": "Voeg toe",
"checkpoint": "Checkpoint",
"details": "Details",
"outputs": "Uitvoeren",
"save": "Bewaar",
"nextPage": "Volgende pagina",
"blue": "Blauw",
"alpha": "Alfa",
"red": "Rood",
"editor": "Editor",
"folder": "Map",
"format": "structuur",
"goTo": "Ga naar",
"template": "Sjabloon",
"input": "Invoer",
"loglevel": "Logboekniveau",
"safetensors": "Safetensors",
"saveAs": "Bewaar als",
"created": "Gemaakt",
"green": "Groen",
"tab": "Tab",
"positivePrompt": "Positieve prompt",
"negativePrompt": "Negatieve prompt",
"selected": "Geselecteerd",
"orderBy": "Sorteer op",
"prevPage": "Vorige pagina",
"beta": "Bèta",
"copyError": "$t(gallery.copy) Fout",
"toResolve": "Op te lossen",
"aboutDesc": "Gebruik je Invoke voor het werk? Kijk dan naar:",
"aboutHeading": "Creatieve macht voor jou",
"copy": "Kopieer",
"data": "Gegevens",
"or": "of",
"updated": "Bijgewerkt",
"outpaint": "outpainten",
"viewing": "Bekijken",
"viewingDesc": "Beoordeel afbeelding in een grote galerijweergave",
"editing": "Bewerken",
"editingDesc": "Bewerk op het canvas Stuurlagen",
"ai": "ai",
"inpaint": "inpainten",
"unknown": "Onbekend",
"delete": "Verwijder",
"direction": "Richting",
"error": "Fout",
"localSystem": "Lokaal systeem",
"unknownError": "Onbekende fout"
},
"gallery": {
"galleryImageSize": "Afbeeldingsgrootte",
@@ -310,10 +363,41 @@
"modelSyncFailed": "Synchronisatie modellen mislukt",
"modelDeleteFailed": "Model kon niet verwijderd worden",
"convertingModelBegin": "Model aan het converteren. Even geduld.",
"predictionType": "Soort voorspelling (voor Stable Diffusion 2.x-modellen en incidentele Stable Diffusion 1.x-modellen)",
"predictionType": "Soort voorspelling",
"advanced": "Uitgebreid",
"modelType": "Soort model",
"vaePrecision": "Nauwkeurigheid VAE"
"vaePrecision": "Nauwkeurigheid VAE",
"loraTriggerPhrases": "LoRA-triggerzinnen",
"urlOrLocalPathHelper": "URL's zouden moeten wijzen naar een los bestand. Lokale paden kunnen wijzen naar een los bestand of map voor een individueel Diffusers-model.",
"modelName": "Modelnaam",
"path": "Pad",
"triggerPhrases": "Triggerzinnen",
"typePhraseHere": "Typ zin hier in",
"useDefaultSettings": "Gebruik standaardinstellingen",
"modelImageDeleteFailed": "Fout bij verwijderen modelafbeelding",
"modelImageUpdated": "Modelafbeelding bijgewerkt",
"modelImageUpdateFailed": "Fout bij bijwerken modelafbeelding",
"noMatchingModels": "Geen overeenkomende modellen",
"scanPlaceholder": "Pad naar een lokale map",
"noModelsInstalled": "Geen modellen geïnstalleerd",
"noModelsInstalledDesc1": "Installeer modellen met de",
"noModelSelected": "Geen model geselecteerd",
"starterModels": "Beginnermodellen",
"textualInversions": "Tekstuele omkeringen",
"upcastAttention": "Upcast-aandacht",
"uploadImage": "Upload afbeelding",
"mainModelTriggerPhrases": "Triggerzinnen hoofdmodel",
"urlOrLocalPath": "URL of lokaal pad",
"scanFolderHelper": "De map zal recursief worden ingelezen voor modellen. Dit kan enige tijd in beslag nemen voor erg grote mappen.",
"simpleModelPlaceholder": "URL of pad naar een lokaal pad of Diffusers-map",
"modelSettings": "Modelinstellingen",
"pathToConfig": "Pad naar configuratie",
"prune": "Snoei",
"pruneTooltip": "Snoei voltooide importeringen uit wachtrij",
"repoVariant": "Repovariant",
"scanFolder": "Lees map in",
"scanResults": "Resultaten inlezen",
"source": "Bron"
},
"parameters": {
"images": "Afbeeldingen",
@@ -353,13 +437,13 @@
"copyImage": "Kopieer afbeelding",
"denoisingStrength": "Sterkte ontruisen",
"scheduler": "Planner",
"seamlessXAxis": "X-as",
"seamlessYAxis": "Y-as",
"seamlessXAxis": "Naadloze tegels in x-as",
"seamlessYAxis": "Naadloze tegels in y-as",
"clipSkip": "Overslaan CLIP",
"negativePromptPlaceholder": "Negatieve prompt",
"controlNetControlMode": "Aansturingsmodus",
"positivePromptPlaceholder": "Positieve prompt",
"maskBlur": "Vervaag",
"maskBlur": "Vervaging van masker",
"invoke": {
"noNodesInGraph": "Geen knooppunten in graaf",
"noModelSelected": "Geen model ingesteld",
@@ -369,11 +453,25 @@
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} invoer ontbreekt",
"noControlImageForControlAdapter": "Controle-adapter #{{number}} heeft geen controle-afbeelding",
"noModelForControlAdapter": "Control-adapter #{{number}} heeft geen model ingesteld staan.",
"incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is ongeldig in combinatie met het hoofdmodel.",
"incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is niet compatibel met het hoofdmodel.",
"systemDisconnected": "Systeem is niet verbonden",
"missingNodeTemplate": "Knooppuntsjabloon ontbreekt",
"missingFieldTemplate": "Veldsjabloon ontbreekt",
"addingImagesTo": "Bezig met toevoegen van afbeeldingen aan"
"addingImagesTo": "Bezig met toevoegen van afbeeldingen aan",
"layer": {
"initialImageNoImageSelected": "geen initiële afbeelding geselecteerd",
"controlAdapterNoModelSelected": "geen controle-adaptermodel geselecteerd",
"controlAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor controle-adapter",
"controlAdapterNoImageSelected": "geen afbeelding voor controle-adapter geselecteerd",
"controlAdapterImageNotProcessed": "Afbeelding voor controle-adapter niet verwerkt",
"ipAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor IP-adapter",
"ipAdapterNoImageSelected": "geen afbeelding voor IP-adapter geselecteerd",
"rgNoRegion": "geen gebied geselecteerd",
"rgNoPromptsOrIPAdapters": "geen tekstprompts of IP-adapters",
"t2iAdapterIncompatibleDimensions": "T2I-adapter vereist een afbeelding met afmetingen met een veelvoud van 64",
"ipAdapterNoModelSelected": "geen IP-adapter geselecteerd"
},
"imageNotProcessedForControlAdapter": "De afbeelding van controle-adapter #{{number}} is niet verwerkt"
},
"isAllowedToUpscale": {
"useX2Model": "Afbeelding is te groot om te vergroten met het x4-model. Gebruik hiervoor het x2-model",
@@ -383,7 +481,26 @@
"useCpuNoise": "Gebruik CPU-ruis",
"imageActions": "Afbeeldingshandeling",
"iterations": "Iteraties",
"coherenceMode": "Modus"
"coherenceMode": "Modus",
"infillColorValue": "Vulkleur",
"remixImage": "Meng afbeelding opnieuw",
"setToOptimalSize": "Optimaliseer grootte voor het model",
"setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (is mogelijk te klein)",
"aspect": "Beeldverhouding",
"infillMosaicTileWidth": "Breedte tegel",
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (is mogelijk te groot)",
"lockAspectRatio": "Zet beeldverhouding vast",
"infillMosaicTileHeight": "Hoogte tegel",
"globalNegativePromptPlaceholder": "Globale negatieve prompt",
"globalPositivePromptPlaceholder": "Globale positieve prompt",
"useSize": "Gebruik grootte",
"swapDimensions": "Wissel afmetingen om",
"globalSettings": "Globale instellingen",
"coherenceEdgeSize": "Randgrootte",
"coherenceMinDenoise": "Min. ontruising",
"infillMosaicMinColor": "Min. kleur",
"infillMosaicMaxColor": "Max. kleur",
"cfgRescaleMultiplier": "Vermenigvuldiger voor CFG-herschaling"
},
"settings": {
"models": "Modellen",
@@ -410,7 +527,12 @@
"intermediatesCleared_one": "{{count}} tussentijdse afbeelding gewist",
"intermediatesCleared_other": "{{count}} tussentijdse afbeeldingen gewist",
"clearIntermediatesDesc1": "Als je tussentijdse afbeeldingen wist, dan wordt de staat hersteld van je canvas en van ControlNet.",
"intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen"
"intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen",
"clearIntermediatesDisabled": "Wachtrij moet leeg zijn om tussentijdse afbeeldingen te kunnen leegmaken",
"enableInformationalPopovers": "Schakel informatieve hulpballonnen in",
"enableInvisibleWatermark": "Schakel onzichtbaar watermerk in",
"enableNSFWChecker": "Schakel NSFW-controle in",
"reloadingIn": "Opnieuw laden na"
},
"toast": {
"uploadFailed": "Upload mislukt",
@@ -425,8 +547,8 @@
"connected": "Verbonden met server",
"canceled": "Verwerking geannuleerd",
"uploadFailedInvalidUploadDesc": "Moet een enkele PNG- of JPEG-afbeelding zijn",
"parameterNotSet": "Parameter niet ingesteld",
"parameterSet": "Instellen parameters",
"parameterNotSet": "{{parameter}} niet ingesteld",
"parameterSet": "{{parameter}} ingesteld",
"problemCopyingImage": "Kan Afbeelding Niet Kopiëren",
"baseModelChangedCleared_one": "Basismodel is gewijzigd: {{count}} niet-compatibel submodel weggehaald of uitgeschakeld",
"baseModelChangedCleared_other": "Basismodel is gewijzigd: {{count}} niet-compatibele submodellen weggehaald of uitgeschakeld",
@@ -443,11 +565,11 @@
"maskSavedAssets": "Masker bewaard in Assets",
"problemDownloadingCanvas": "Fout bij downloaden van canvas",
"problemMergingCanvas": "Fout bij samenvoegen canvas",
"setCanvasInitialImage": "Ingesteld als initiële canvasafbeelding",
"setCanvasInitialImage": "Initiële canvasafbeelding ingesteld",
"imageUploaded": "Afbeelding geüpload",
"addedToBoard": "Toegevoegd aan bord",
"workflowLoaded": "Werkstroom geladen",
"modelAddedSimple": "Model toegevoegd",
"modelAddedSimple": "Model toegevoegd aan wachtrij",
"problemImportingMaskDesc": "Kan masker niet exporteren",
"problemCopyingCanvas": "Fout bij kopiëren canvas",
"problemSavingCanvas": "Fout bij bewaren canvas",
@@ -459,7 +581,18 @@
"maskSentControlnetAssets": "Masker gestuurd naar ControlNet en Assets",
"canvasSavedGallery": "Canvas bewaard in galerij",
"imageUploadFailed": "Fout bij uploaden afbeelding",
"problemImportingMask": "Fout bij importeren masker"
"problemImportingMask": "Fout bij importeren masker",
"workflowDeleted": "Werkstroom verwijderd",
"invalidUpload": "Ongeldige upload",
"uploadInitialImage": "Initiële afbeelding uploaden",
"setAsCanvasInitialImage": "Ingesteld als initiële afbeelding voor canvas",
"problemRetrievingWorkflow": "Fout bij ophalen van werkstroom",
"parameters": "Parameters",
"modelImportCanceled": "Importeren model geannuleerd",
"problemDeletingWorkflow": "Fout bij verwijderen van werkstroom",
"prunedQueue": "Wachtrij gesnoeid",
"problemDownloadingImage": "Fout bij downloaden afbeelding",
"resetInitialImage": "Initiële afbeelding hersteld"
},
"tooltip": {
"feature": {
@@ -533,7 +666,11 @@
"showOptionsPanel": "Toon zijscherm",
"menu": "Menu",
"showGalleryPanel": "Toon deelscherm Galerij",
"loadMore": "Laad meer"
"loadMore": "Laad meer",
"about": "Over",
"mode": "Modus",
"resetUI": "$t(accessibility.reset) UI",
"createIssue": "Maak probleem aan"
},
"nodes": {
"zoomOutNodes": "Uitzoomen",
@@ -547,7 +684,7 @@
"loadWorkflow": "Laad werkstroom",
"downloadWorkflow": "Download JSON van werkstroom",
"scheduler": "Planner",
"missingTemplate": "Ontbrekende sjabloon",
"missingTemplate": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een ontbrekend sjabloon (niet geïnstalleerd?)",
"workflowDescription": "Korte beschrijving",
"versionUnknown": " Versie onbekend",
"noNodeSelected": "Geen knooppunt gekozen",
@@ -563,7 +700,7 @@
"integer": "Geheel getal",
"nodeTemplate": "Sjabloon knooppunt",
"nodeOpacity": "Dekking knooppunt",
"unableToLoadWorkflow": "Kan werkstroom niet valideren",
"unableToLoadWorkflow": "Fout bij laden werkstroom",
"snapToGrid": "Lijn uit op raster",
"noFieldsLinearview": "Geen velden toegevoegd aan lineaire weergave",
"nodeSearch": "Zoek naar knooppunten",
@@ -614,11 +751,56 @@
"unknownField": "Onbekend veld",
"colorCodeEdges": "Kleurgecodeerde randen",
"unknownNode": "Onbekend knooppunt",
"mismatchedVersion": "Heeft niet-overeenkomende versie",
"mismatchedVersion": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een niet-overeenkomende versie (probeer het bij te werken?)",
"addNodeToolTip": "Voeg knooppunt toe (Shift+A, spatie)",
"loadingNodes": "Bezig met laden van knooppunten...",
"snapToGridHelp": "Lijn knooppunten uit op raster bij verplaatsing",
"workflowSettings": "Instellingen werkstroomeditor"
"workflowSettings": "Instellingen werkstroomeditor",
"addLinearView": "Voeg toe aan lineaire weergave",
"nodePack": "Knooppuntpakket",
"unknownInput": "Onbekende invoer: {{name}}",
"sourceNodeFieldDoesNotExist": "Ongeldige rand: bron-/uitvoerveld {{node}}.{{field}} bestaat niet",
"collectionFieldType": "Verzameling {{name}}",
"deletedInvalidEdge": "Ongeldige hoek {{source}} -> {{target}} verwijderd",
"graph": "Grafiek",
"targetNodeDoesNotExist": "Ongeldige rand: doel-/invoerknooppunt {{node}} bestaat niet",
"resetToDefaultValue": "Herstel naar standaardwaarden",
"editMode": "Bewerk in Werkstroom-editor",
"showEdgeLabels": "Toon randlabels",
"showEdgeLabelsHelp": "Toon labels aan randen, waarmee de verbonden knooppunten mee worden aangegeven",
"clearWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.",
"unableToParseFieldType": "fout bij bepalen soort veld",
"sourceNodeDoesNotExist": "Ongeldige rand: bron-/uitvoerknooppunt {{node}} bestaat niet",
"unsupportedArrayItemType": "niet-ondersteunde soort van het array-onderdeel \"{{type}}\"",
"targetNodeFieldDoesNotExist": "Ongeldige rand: doel-/invoerveld {{node}}.{{field}} bestaat niet",
"reorderLinearView": "Herorden lineaire weergave",
"newWorkflowDesc": "Een nieuwe werkstroom aanmaken?",
"collectionOrScalarFieldType": "Verzameling|scalair {{name}}",
"newWorkflow": "Nieuwe werkstroom",
"unknownErrorValidatingWorkflow": "Onbekende fout bij valideren werkstroom",
"unsupportedAnyOfLength": "te veel union-leden ({{count}})",
"unknownOutput": "Onbekende uitvoer: {{name}}",
"viewMode": "Gebruik in lineaire weergave",
"unableToExtractSchemaNameFromRef": "fout bij het extraheren van de schemanaam via de ref",
"unsupportedMismatchedUnion": "niet-overeenkomende soort CollectionOrScalar met basissoorten {{firstType}} en {{secondType}}",
"unknownNodeType": "Onbekend soort knooppunt",
"edit": "Bewerk",
"updateAllNodes": "Werk knooppunten bij",
"allNodesUpdated": "Alle knooppunten bijgewerkt",
"nodeVersion": "Knooppuntversie",
"newWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.",
"clearWorkflow": "Maak werkstroom leeg",
"clearWorkflowDesc": "Deze werkstroom leegmaken en met een nieuwe beginnen?",
"inputFieldTypeParseError": "Fout bij bepalen van het soort invoerveld {{node}}.{{field}} ({{message}})",
"outputFieldTypeParseError": "Fout bij het bepalen van het soort uitvoerveld {{node}}.{{field}} ({{message}})",
"unableToExtractEnumOptions": "fout bij extraheren enumeratie-opties",
"unknownFieldType": "Soort $t(nodes.unknownField): {{type}}",
"unableToGetWorkflowVersion": "Fout bij ophalen schemaversie van werkstroom",
"betaDesc": "Deze uitvoering is in bèta. Totdat deze stabiel is kunnen er wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. We zijn van plan om deze uitvoering op lange termijn te gaan ondersteunen.",
"prototypeDesc": "Deze uitvoering is een prototype. Er kunnen wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. Deze kunnen op een willekeurig moment verwijderd worden.",
"noFieldsViewMode": "Deze werkstroom heeft geen geselecteerde velden om te tonen. Bekijk de volledige werkstroom om de waarden te configureren.",
"unableToUpdateNodes_one": "Fout bij bijwerken van {{count}} knooppunt",
"unableToUpdateNodes_other": "Fout bij bijwerken van {{count}} knooppunten"
},
"controlnet": {
"amult": "a_mult",
@@ -691,9 +873,28 @@
"canny": "Canny",
"depthZoeDescription": "Genereer diepteblad via Zoe",
"hedDescription": "Herkenning van holistisch-geneste randen",
"setControlImageDimensions": "Stel afmetingen controle-afbeelding in op B/H",
"setControlImageDimensions": "Kopieer grootte naar B/H (optimaliseer voor model)",
"scribble": "Krabbel",
"maxFaces": "Max. gezichten"
"maxFaces": "Max. gezichten",
"dwOpenpose": "DW Openpose",
"depthAnything": "Depth Anything",
"base": "Basis",
"hands": "Handen",
"selectCLIPVisionModel": "Selecteer een CLIP Vision-model",
"modelSize": "Modelgrootte",
"small": "Klein",
"large": "Groot",
"resizeSimple": "Wijzig grootte (eenvoudig)",
"beginEndStepPercentShort": "Begin-/eind-%",
"depthAnythingDescription": "Genereren dieptekaart d.m.v. de techniek Depth Anything",
"face": "Gezicht",
"body": "Lichaam",
"dwOpenposeDescription": "Schatting menselijke pose d.m.v. DW Openpose",
"ipAdapterMethod": "Methode",
"full": "Volledig",
"style": "Alleen stijl",
"composition": "Alleen samenstelling",
"setControlImageDimensionsForce": "Kopieer grootte naar B/H (negeer model)"
},
"dynamicPrompts": {
"seedBehaviour": {
@@ -706,7 +907,10 @@
"maxPrompts": "Max. prompts",
"promptsWithCount_one": "{{count}} prompt",
"promptsWithCount_other": "{{count}} prompts",
"dynamicPrompts": "Dynamische prompts"
"dynamicPrompts": "Dynamische prompts",
"showDynamicPrompts": "Toon dynamische prompts",
"loading": "Genereren van dynamische prompts...",
"promptsPreview": "Voorvertoning prompts"
},
"popovers": {
"noiseUseCPU": {
@@ -719,7 +923,7 @@
},
"paramScheduler": {
"paragraphs": [
"De planner bepaalt hoe ruis per iteratie wordt toegevoegd aan een afbeelding of hoe een monster wordt bijgewerkt op basis van de uitvoer van een model."
"De planner gebruikt gedurende het genereringsproces."
],
"heading": "Planner"
},
@@ -806,8 +1010,8 @@
},
"clipSkip": {
"paragraphs": [
"Kies hoeveel CLIP-modellagen je wilt overslaan.",
"Bepaalde modellen werken beter met bepaalde Overslaan CLIP-instellingen."
"Aantal over te slaan CLIP-modellagen.",
"Bepaalde modellen zijn beter geschikt met bepaalde Overslaan CLIP-instellingen."
],
"heading": "Overslaan CLIP"
},
@@ -991,17 +1195,26 @@
"denoisingStrength": "Sterkte ontruising",
"refinermodel": "Verfijningsmodel",
"posAestheticScore": "Positieve esthetische score",
"concatPromptStyle": "Plak prompt- en stijltekst aan elkaar",
"concatPromptStyle": "Koppelen van prompt en stijl",
"loading": "Bezig met laden...",
"steps": "Stappen",
"posStylePrompt": "Positieve-stijlprompt"
"posStylePrompt": "Positieve-stijlprompt",
"freePromptStyle": "Handmatige stijlprompt",
"refinerSteps": "Aantal stappen verfijner"
},
"models": {
"noMatchingModels": "Geen overeenkomend modellen",
"loading": "bezig met laden",
"noMatchingLoRAs": "Geen overeenkomende LoRA's",
"noModelsAvailable": "Geen modellen beschikbaar",
"selectModel": "Kies een model"
"selectModel": "Kies een model",
"noLoRAsInstalled": "Geen LoRA's geïnstalleerd",
"noRefinerModelsInstalled": "Geen SDXL-verfijningsmodellen geïnstalleerd",
"defaultVAE": "Standaard-VAE",
"lora": "LoRA",
"esrganModel": "ESRGAN-model",
"addLora": "Voeg LoRA toe",
"concepts": "Concepten"
},
"boards": {
"autoAddBoard": "Voeg automatisch bord toe",
@@ -1019,7 +1232,13 @@
"downloadBoard": "Download bord",
"changeBoard": "Wijzig bord",
"loading": "Bezig met laden...",
"clearSearch": "Maak zoekopdracht leeg"
"clearSearch": "Maak zoekopdracht leeg",
"deleteBoard": "Verwijder bord",
"deleteBoardAndImages": "Verwijder bord en afbeeldingen",
"deleteBoardOnly": "Verwijder alleen bord",
"deletedBoardsCannotbeRestored": "Verwijderde borden kunnen niet worden hersteld",
"movingImagesToBoard_one": "Verplaatsen van {{count}} afbeelding naar bord:",
"movingImagesToBoard_other": "Verplaatsen van {{count}} afbeeldingen naar bord:"
},
"invocationCache": {
"disable": "Schakel uit",
@@ -1036,5 +1255,39 @@
"clear": "Wis",
"maxCacheSize": "Max. grootte cache",
"cacheSize": "Grootte cache"
},
"accordions": {
"generation": {
"title": "Genereren"
},
"image": {
"title": "Afbeelding"
},
"advanced": {
"title": "Geavanceerd",
"options": "$t(accordions.advanced.title) Opties"
},
"control": {
"title": "Besturing"
},
"compositing": {
"title": "Samenstellen",
"coherenceTab": "Coherentiefase",
"infillTab": "Invullen"
}
},
"hrf": {
"upscaleMethod": "Opschaalmethode",
"metadata": {
"strength": "Sterkte oplossing voor hoge resolutie",
"method": "Methode oplossing voor hoge resolutie",
"enabled": "Oplossing voor hoge resolutie ingeschakeld"
},
"hrf": "Oplossing voor hoge resolutie",
"enableHrf": "Schakel oplossing in voor hoge resolutie"
},
"prompt": {
"addPromptTrigger": "Voeg prompttrigger toe",
"compatibleEmbeddings": "Compatibele embeddings"
}
}
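Several of the Dutch strings added above, e.g. `"notInstalled": "Niet $t(common.installed)"` and `"resetUI": "$t(accessibility.reset) UI"`, use i18next's nesting syntax: `$t(key)` is expanded when the string is looked up, so shared words stay defined in one place. A small sketch, assuming a plain i18next setup; the resource block repeats only keys visible in the diff:

```ts
import i18next from "i18next";

await i18next.init({
  lng: "nl",
  resources: {
    nl: {
      translation: {
        common: {
          installed: "Geïnstalleerd",
          notInstalled: "Niet $t(common.installed)",
        },
      },
    },
  },
});

// $t(...) references are resolved recursively at lookup time.
console.log(i18next.t("common.notInstalled")); // "Niet Geïnstalleerd"
```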


@@ -87,7 +87,11 @@
"viewing": "Просмотр",
"editing": "Редактирование",
"viewingDesc": "Просмотр изображений в режиме большой галереи",
"editingDesc": "Редактировать на холсте слоёв управления"
"editingDesc": "Редактировать на холсте слоёв управления",
"enabled": "Включено",
"disabled": "Отключено",
"comparingDesc": "Сравнение двух изображений",
"comparing": "Сравнение"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@@ -124,7 +128,23 @@
"bulkDownloadRequested": "Подготовка к скачиванию",
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания",
"alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения"
"alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения",
"openInViewer": "Открыть в просмотрщике",
"selectForCompare": "Выбрать для сравнения",
"hover": "Наведение",
"swapImages": "Поменять местами",
"stretchToFit": "Растягивание до нужного размера",
"exitCompare": "Выйти из сравнения",
"compareHelp4": "Нажмите <Kbd>Z</Kbd> или <Kbd>Esc</Kbd> для выхода.",
"compareImage": "Сравнить изображение",
"viewerImage": "Изображение просмотрщика",
"selectAnImageToCompare": "Выберите изображение для сравнения",
"slider": "Слайдер",
"sideBySide": "Бок о бок",
"compareOptions": "Варианты сравнения",
"compareHelp1": "Удерживайте <Kbd>Alt</Kbd> при нажатии на изображение в галерее или при помощи клавиш со стрелками, чтобы изменить сравниваемое изображение.",
"compareHelp2": "Нажмите <Kbd>M</Kbd>, чтобы переключиться между режимами сравнения.",
"compareHelp3": "Нажмите <Kbd>C</Kbd>, чтобы поменять местами сравниваемые изображения."
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@@ -528,7 +548,20 @@
"missingFieldTemplate": "Отсутствует шаблон поля",
"addingImagesTo": "Добавление изображений в",
"invoke": "Создать",
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается"
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается",
"layer": {
"controlAdapterImageNotProcessed": "Изображение адаптера контроля не обработано",
"ipAdapterNoModelSelected": "IP адаптер не выбран",
"controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
"controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
"controlAdapterNoImageSelected": "не выбрано изображение контрольного адаптера",
"initialImageNoImageSelected": "начальное изображение не выбрано",
"rgNoRegion": "регион не выбран",
"rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
"ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
"t2iAdapterIncompatibleDimensions": "Адаптер T2I требует, чтобы размеры изображения были кратны {{multiple}}",
"ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано"
}
},
"isAllowedToUpscale": {
"useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
@@ -606,12 +639,12 @@
"connected": "Подключено к серверу",
"canceled": "Обработка отменена",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
"parameterNotSet": "Параметр {{parameter}} не задан",
"parameterSet": "Параметр {{parameter}} задан",
"parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр задан",
"problemCopyingImage": "Не удается скопировать изображение",
"baseModelChangedCleared_one": "Базовая модель изменила, очистила или отключила {{count}} несовместимую подмодель",
"baseModelChangedCleared_few": "Базовая модель изменила, очистила или отключила {{count}} несовместимые подмодели",
"baseModelChangedCleared_many": "Базовая модель изменила, очистила или отключила {{count}} несовместимых подмоделей",
"baseModelChangedCleared_one": "Очищена или отключена {{count}} несовместимая подмодель",
"baseModelChangedCleared_few": "Очищены или отключены {{count}} несовместимые подмодели",
"baseModelChangedCleared_many": "Очищены или отключены {{count}} несовместимых подмоделей",
"imageSavingFailed": "Не удалось сохранить изображение",
"canvasSentControlnetAssets": "Холст отправлен в ControlNet и ресурсы",
"problemCopyingCanvasDesc": "Невозможно экспортировать базовый слой",
@@ -652,7 +685,17 @@
"resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен",
"parameters": "Параметры"
"parameters": "Параметры",
"parameterSetDesc": "Задан {{parameter}}",
"parameterNotSetDesc": "Невозможно задать {{parameter}}",
"baseModelChanged": "Базовая модель сменена",
"parameterNotSetDescWithMessage": "Не удалось задать {{parameter}}: {{message}}",
"parametersSet": "Параметры заданы",
"errorCopied": "Ошибка скопирована",
"sessionRef": "Сессия: {{sessionId}}",
"outOfMemoryError": "Ошибка нехватки памяти",
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
"somethingWentWrong": "Что-то пошло не так"
},
"tooltip": {
"feature": {
@@ -739,7 +782,8 @@
"loadMore": "Загрузить больше",
"resetUI": "$t(accessibility.reset) интерфейс",
"createIssue": "Сообщить о проблеме",
"about": "Об этом"
"about": "Об этом",
"submitSupportTicket": "Отправить тикет в службу поддержки"
},
"nodes": {
"zoomInNodes": "Увеличьте масштаб",
@@ -832,7 +876,7 @@
"workflowName": "Название",
"collection": "Коллекция",
"unknownErrorValidatingWorkflow": "Неизвестная ошибка при проверке рабочего процесса",
"collectionFieldType": "Коллекция {{name}}",
"collectionFieldType": "{{name}} (Коллекция)",
"workflowNotes": "Примечания",
"string": "Строка",
"unknownNodeType": "Неизвестный тип узла",
@@ -848,7 +892,7 @@
"targetNodeDoesNotExist": "Недопустимое ребро: целевой/входной узел {{node}} не существует",
"mismatchedVersion": "Недопустимый узел: узел {{node}} типа {{type}} имеет несоответствующую версию (попробовать обновить?)",
"unknownFieldType": "$t(nodes.unknownField) тип: {{type}}",
"collectionOrScalarFieldType": "Коллекция | Скаляр {{name}}",
"collectionOrScalarFieldType": "{{name}} (Один или коллекция)",
"betaDesc": "Этот вызов находится в бета-версии. Пока он не станет стабильным, в нем могут происходить изменения при обновлении приложений. Мы планируем поддерживать этот вызов в течение длительного времени.",
"nodeVersion": "Версия узла",
"loadingNodes": "Загрузка узлов...",
@@ -870,7 +914,16 @@
"noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.",
"graph": "График",
"showEdgeLabels": "Показать метки на ребрах",
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы"
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы",
"cannotMixAndMatchCollectionItemTypes": "Невозможно смешивать и сопоставлять типы элементов коллекции",
"missingNode": "Отсутствует узел вызова",
"missingInvocationTemplate": "Отсутствует шаблон вызова",
"missingFieldTemplate": "Отсутствующий шаблон поля",
"singleFieldType": "{{name}} (Один)",
"noGraph": "Нет графика",
"imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
"boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
"modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию"
},
"controlnet": {
"amult": "a_mult",
@@ -1441,7 +1494,16 @@
"clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?",
"item": "Элемент",
"graphFailedToQueue": "Не удалось поставить график в очередь",
"openQueue": "Открыть очередь"
"openQueue": "Открыть очередь",
"prompts_one": "Запрос",
"prompts_few": "Запроса",
"prompts_many": "Запросов",
"iterations_one": "Итерация",
"iterations_few": "Итерации",
"iterations_many": "Итераций",
"generations_one": "Генерация",
"generations_few": "Генерации",
"generations_many": "Генераций"
},
"sdxl": {
"refinerStart": "Запуск доработчика",
@@ -1594,7 +1656,6 @@
"deleteAll": "Удалить всё",
"addLayer": "Добавить слой",
"moveToFront": "На передний план",
"toggleVisibility": "Переключить видимость слоя",
"addPositivePrompt": "Добавить $t(common.positivePrompt)",
"addIPAdapter": "Добавить $t(common.ipAdapter)",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",


@@ -1,6 +1,6 @@
{
"common": {
"nodes": "節點",
"nodes": "工作流程",
"img2img": "圖片轉圖片",
"statusDisconnected": "已中斷連線",
"back": "返回",
@@ -11,17 +11,239 @@
"reportBugLabel": "回報錯誤",
"githubLabel": "GitHub",
"hotkeysLabel": "快捷鍵",
"languagePickerLabel": "切換語言",
"languagePickerLabel": "語言",
"unifiedCanvas": "統一畫布",
"cancel": "取消",
"txt2img": "文字轉圖片"
"txt2img": "文字轉圖片",
"controlNet": "ControlNet",
"advanced": "進階",
"folder": "資料夾",
"installed": "已安裝",
"accept": "接受",
"goTo": "前往",
"input": "輸入",
"random": "隨機",
"selected": "已選擇",
"communityLabel": "社群",
"loading": "載入中",
"delete": "刪除",
"copy": "複製",
"error": "錯誤",
"file": "檔案",
"format": "格式",
"imageFailedToLoad": "無法載入圖片"
},
"accessibility": {
"invokeProgressBar": "Invoke 進度條",
"uploadImage": "上傳圖片",
"reset": "重",
"reset": "重",
"nextImage": "下一張圖片",
"previousImage": "上一張圖片",
"menu": "選單"
"menu": "選單",
"loadMore": "載入更多",
"about": "關於",
"createIssue": "建立問題",
"resetUI": "$t(accessibility.reset) 介面",
"submitSupportTicket": "提交支援工單",
"mode": "模式"
},
"boards": {
"loading": "載入中…",
"movingImagesToBoard_other": "正在移動 {{count}} 張圖片至板上:",
"move": "移動",
"uncategorized": "未分類",
"cancel": "取消"
},
"metadata": {
"workflow": "工作流程",
"steps": "步數",
"model": "模型",
"seed": "種子",
"vae": "VAE",
"seamless": "無縫",
"metadata": "元數據",
"width": "寬度",
"height": "高度"
},
"accordions": {
"control": {
"title": "控制"
},
"compositing": {
"title": "合成"
},
"advanced": {
"title": "進階",
"options": "$t(accordions.advanced.title) 選項"
}
},
"hotkeys": {
"nodesHotkeys": "節點",
"cancel": {
"title": "取消"
},
"generalHotkeys": "一般",
"keyboardShortcuts": "快捷鍵",
"appHotkeys": "應用程式"
},
"modelManager": {
"advanced": "進階",
"allModels": "全部模型",
"variant": "變體",
"config": "配置",
"model": "模型",
"selected": "已選擇",
"huggingFace": "HuggingFace",
"install": "安裝",
"metadata": "元數據",
"delete": "刪除",
"description": "描述",
"cancel": "取消",
"convert": "轉換",
"manual": "手動",
"none": "無",
"name": "名稱",
"load": "載入",
"height": "高度",
"width": "寬度",
"search": "搜尋",
"vae": "VAE",
"settings": "設定"
},
"controlnet": {
"mlsd": "M-LSD",
"canny": "Canny",
"duplicate": "重複",
"none": "無",
"pidi": "PIDI",
"h": "H",
"balanced": "平衡",
"crop": "裁切",
"processor": "處理器",
"control": "控制",
"f": "F",
"lineart": "線條藝術",
"w": "W",
"hed": "HED",
"delete": "刪除"
},
"queue": {
"queue": "佇列",
"canceled": "已取消",
"failed": "已失敗",
"completed": "已完成",
"cancel": "取消",
"session": "工作階段",
"batch": "批量",
"item": "項目",
"completedIn": "完成於",
"notReady": "無法排隊"
},
"parameters": {
"cancel": {
"cancel": "取消"
},
"height": "高度",
"type": "類型",
"symmetry": "對稱性",
"images": "圖片",
"width": "寬度",
"coherenceMode": "模式",
"seed": "種子",
"general": "一般",
"strength": "強度",
"steps": "步數",
"info": "資訊"
},
"settings": {
"beta": "Beta",
"developer": "開發者",
"general": "一般",
"models": "模型"
},
"popovers": {
"paramModel": {
"heading": "模型"
},
"compositingCoherenceMode": {
"heading": "模式"
},
"paramSteps": {
"heading": "步數"
},
"controlNetProcessor": {
"heading": "處理器"
},
"paramVAE": {
"heading": "VAE"
},
"paramHeight": {
"heading": "高度"
},
"paramSeed": {
"heading": "種子"
},
"paramWidth": {
"heading": "寬度"
},
"refinerSteps": {
"heading": "步數"
}
},
"unifiedCanvas": {
"undo": "復原",
"mask": "遮罩",
"eraser": "橡皮擦",
"antialiasing": "抗鋸齒",
"redo": "重做",
"layer": "圖層",
"accept": "接受",
"brush": "刷子",
"move": "移動",
"brushSize": "大小"
},
"nodes": {
"workflowName": "名稱",
"notes": "註釋",
"workflowVersion": "版本",
"workflowNotes": "註釋",
"executionStateError": "錯誤",
"unableToUpdateNodes_other": "無法更新 {{count}} 個節點",
"integer": "整數",
"workflow": "工作流程",
"enum": "枚舉",
"edit": "編輯",
"string": "字串",
"workflowTags": "標籤",
"node": "節點",
"boolean": "布林值",
"workflowAuthor": "作者",
"version": "版本",
"executionStateCompleted": "已完成",
"edge": "邊緣",
"versionUnknown": " 版本未知"
},
"sdxl": {
"steps": "步數",
"loading": "載入中…",
"refiner": "精煉器"
},
"gallery": {
"copy": "複製",
"download": "下載",
"loading": "載入中"
},
"ui": {
"tabs": {
"models": "模型",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"queue": "佇列"
}
},
"models": {
"loading": "載入中"
},
"workflows": {
"name": "名稱"
}
}
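The `<Kbd>` tags inside the `gallery.compareHelp*` strings earlier in this diff (Italian and Russian hunks) are element placeholders rather than literal markup; with react-i18next they are typically rendered through the `<Trans>` component, which maps named tags onto React elements. A hedged TSX sketch of that pattern, assuming react-i18next; the `Kbd` component below is a stand-in, not InvokeAI's actual implementation:

```tsx
import type { ReactNode } from "react";
import { Trans } from "react-i18next";

// Stand-in keyboard-key component; the real app presumably supplies its own.
const Kbd = ({ children }: { children?: ReactNode }) => <kbd>{children}</kbd>;

export const CompareHelp = () => (
  // "compareHelp2": "Premi <Kbd>M</Kbd> per scorrere le modalità di confronto."
  <Trans i18nKey="gallery.compareHelp2" components={{ Kbd: <Kbd /> }} />
);
```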

Some files were not shown because too many files have changed in this diff.